diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index dadf27e21f..a080f52d5f 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -25,8 +25,6 @@ #define TVM_RELAX_TRANSFORM_H_ #include -#include -#include #include namespace tvm { @@ -119,14 +117,6 @@ TVM_DLL Pass CanonicalizeBindings(); */ TVM_DLL Pass Normalize(); -/*! - * \brief Apply the best schedule from tuning database. - * - * \return The Pass. - */ -TVM_DLL Pass MetaScheduleApplyHistoryBest(const tvm::meta_schedule::Database& database, - Target target); - /*! * \brief Bind params of function of the module to constant tensors. * diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 887981ccff..dacfc361a6 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -38,12 +38,17 @@ namespace ir { */ class IRModuleFrameNode : public IRBuilderFrameNode { public: - Array global_vars; - Array functions; + /*! \brief A map from string names to global variables that ensures global uniqueness. */ + Map global_var_map; + /*! + * \brief A map from GlobalVar to all global functions. + * \note Only defined functions are in the map, while declared functions are not included. + */ + Map functions; void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_vars); + v->Visit("global_vars", &global_var_map); v->Visit("functions", &functions); } diff --git a/include/tvm/script/ir_builder/ir/ir.h b/include/tvm/script/ir_builder/ir/ir.h index f0e7cc6f5c..10996a7b10 100644 --- a/include/tvm/script/ir_builder/ir/ir.h +++ b/include/tvm/script/ir_builder/ir/ir.h @@ -37,6 +37,21 @@ namespace ir { */ TVM_DLL IRModuleFrame IRModule(); +/*! + * \brief Declare a Function without given the specific function implementation. + * \note It is usually used in cross-function call. And we can specify the function by `DefFunction` + * \param func_name The function unique name. + * \return The corresponding GlobalVar. + */ +TVM_DLL GlobalVar DeclFunction(const String& func_name); + +/*! + * \brief Define the function which is declared before. + * \param func_name The function unique name. + * \param func The given function implementation + */ +TVM_DLL void DefFunction(const String& func_name, const BaseFunc& func); + } // namespace ir } // namespace ir_builder } // namespace script diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h new file mode 100644 index 0000000000..a1e908aef3 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +/*! \brief The base ir_builder frame for the relax dialect. */ +class RelaxFrameNode : public IRBuilderFrameNode { + public: + void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); +}; + +class RelaxFrame : public IRBuilderFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RelaxFrame, IRBuilderFrame, RelaxFrameNode); + + protected: + RelaxFrame() = default; +}; + +/*! \brief The base ir_builder frame for frames with SeqExpr + i.e. Functions, If branches + */ +class SeqExprFrameNode : public RelaxFrameNode { + public: + /*! \brief The binding blocks inside the frame. */ + Array binding_blocks; + /*! \brief The frame output expr. `NullOpt` when undefined. */ + Optional output; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); + + public: + void ExitWithScope() override; +}; + +class SeqExprFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SeqExprFrame, RelaxFrame, SeqExprFrameNode); +}; + +/*! \brief The ir_builder frame for the relax function. */ +class FunctionFrameNode : public SeqExprFrameNode { + public: + /*! + * \brief The function name. + * \note The name will not be specified in constructor, so it is "Optional", + * However, we must specify the name by `R.func_name` before exit this frame. + */ + Optional name; + /*! \brief The function params. */ + Array params; + /*! + * \brief The function return type. + * \note Usually the function return type can be deduced by the function body. + * But we can use this field to specify a more "accurate" return type. + * i.e. If the `ret_type` is None, try to use the deduced type from body + * If the `ret_type` is not None, check the deduced type is a base type of the given one. + */ + Optional ret_type; + /*! \brief The function attributes. */ + Map attrs; + /*! \brief The block builder to create Relax function. */ + tvm::relax::BlockBuilder block_builder; + + void VisitAttrs(tvm::AttrVisitor* v) { + SeqExprFrameNode::VisitAttrs(v); + v->Visit("name", &name); + v->Visit("params", ¶ms); + v->Visit("ret_type", &ret_type); + v->Visit("attrs", &attrs); + v->Visit("binding_blocks", &binding_blocks); + v->Visit("output", &output); + // `block_builder` is not visited. + } + + static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); + + public: + void ExitWithScope() final; +}; + +class FunctionFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionFrame, SeqExprFrame, FunctionFrameNode); +}; + +/*! \brief The ir_builder frame for relax binding blocks. */ +class BlockFrameNode : public RelaxFrameNode { + public: + /*! \brief The flag that indicates whether the block is a dataflow block. */ + bool is_dataflow; + /*! \brief The variables emitted in this block. */ + Array emitted_vars; + /*! + * \brief (Only used for a dataflow block.) A boolean indicating if the dataflow block is ended of + * construction. If it is true, any new binding trying to be emitted into this block will cause an + * error. + */ + bool block_ended; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("is_dataflow", &is_dataflow); + v->Visit("emitted_vars", &emitted_vars); + v->Visit("block_ended", &block_ended); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); + + public: + void EnterWithScope() final; + void ExitWithScope() final; +}; + +class BlockFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, RelaxFrame, BlockFrameNode); +}; + +/*! + * \brief A frame that represents if statement. + * + * \sa IfFrame + */ +class IfFrameNode : public RelaxFrameNode { + public: + /*! \brief The condition of the if statement. */ + tvm::relax::Expr condition; + /*! \brief The Bindings in the true branch. */ + Optional then_expr; + /*! \brief The Bindings in the false branch. */ + Optional else_expr; + /*! \brief The Binding var. */ + tvm::relax::Var var; + /*! \brief The binding var name. */ + String var_name; + + void VisitAttrs(tvm::AttrVisitor* v) { + RelaxFrameNode::VisitAttrs(v); + v->Visit("condition", &condition); + v->Visit("then_expr", &then_expr); + v->Visit("else_expr", &else_expr); + v->Visit("var", &var); + v->Visit("var_name", &var_name); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to IfFrameNode. + * + * \sa IfFrameNode + */ +class IfFrame : public RelaxFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(IfFrame, RelaxFrame, IfFrameNode); +}; + +/*! + * \brief A frame that represents then. + * + * \sa ThenFrame + */ +class ThenFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ThenFrameNode. + * + * \sa ThenFrameNode + */ +class ThenFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ThenFrame, SeqExprFrame, ThenFrameNode); +}; + +/*! + * \brief A frame that represents else. + * + * \sa ElseFrame + */ +class ElseFrameNode : public SeqExprFrameNode { + public: + static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; + TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); + + public: + /*! + * \brief The method called when entering RAII scope. + * \sa tvm::support::With + */ + void EnterWithScope() final; + /*! + * \brief The method called when exiting RAII scope. + * \sa tvm::support::With + */ + void ExitWithScope() final; +}; + +/*! + * \brief Managed reference to ElseFrameNode. + * + * \sa ElseFrameNode + */ +class ElseFrame : public SeqExprFrame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ElseFrame, SeqExprFrame, ElseFrameNode); +}; + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ diff --git a/include/tvm/script/ir_builder/relax/ir.h b/include/tvm/script/ir_builder/relax/ir.h new file mode 100644 index 0000000000..2499a87f51 --- /dev/null +++ b/include/tvm/script/ir_builder/relax/ir.h @@ -0,0 +1,190 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ + +#include +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +////////////////////////////// Tensor Type ////////////////////////////// + +/*! \brief A temporary Tensor type for `R.Tensor` in ir_builder. */ +class TensorTypeNode : public runtime::Object { + public: + /*! \brief The type, usually is DynTensorType */ + tvm::relax::DynTensorType type; + /*! \brief The shape, which is optional. */ + Optional shape; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("type", &type); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "script.ir_builder.relax.TensorType"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorTypeNode, runtime::Object); +}; + +class TensorType : public runtime::ObjectRef { + public: + TVM_DLL explicit TensorType(tvm::relax::DynTensorType type, Optional shape); + + TVM_DEFINE_OBJECT_REF_METHODS(TensorType, ObjectRef, TensorTypeNode); +}; + +/*! + * \brief Create a TensorType for a DynTensor. + * \param shape The shape of the tensor. It's runtime dependent if `shape` is None. + * \param dtype The element data type of the tensor. It's runtime dependent if `dtype` is None. + * \param ndim The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1. + * \return The TensorType that is only used in ir_builder. + */ +TVM_DLL TensorType Tensor(Optional> shape, DataType dtype, int ndim = -1); + +/////////////////////////////// Function //////////////////////////////// + +/*! + * \brief Start a function frame. + * \return The created ir_builder Function frame. + */ +TVM_DLL FunctionFrame Function(); + +/*! + * \brief Add a parameter to the last function frame. + * \param name The name of the parameter. + * \param type The type of the parameter. + * \param shape The shape of the parameter. + * \return The created function parameter var. + */ +TVM_DLL tvm::relax::Var Arg(const String& name, const Type& type, + const tvm::relax::ShapeExpr& shape); + +/*! + * \brief Specify the name of the last function frame. + * \param name The function name. + */ +TVM_DLL void FuncName(const String& name); + +/*! + * \brief Specify the attrs of the last function frame. + * \param attrs The function attrs. + */ +TVM_DLL void FuncAttrs(Map attrs); + +/*! + * \brief Specify the return type of the last function frame. + * \param ret_type The return type. Note: it's a standard `tvm::Type` instead of TensorType. + */ +TVM_DLL void FuncRetType(tvm::Type ret_type); + +/*! + * \brief Specify the return value of the last function frame. + * \param value The return value. + */ +TVM_DLL void FuncRetValue(const tvm::relax::Expr& value); + +///////////////////////////// BindingBlock ////////////////////////////// + +/*! + * \brief Start a binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame BindingBlock(); + +/*! + * \brief Start a dataflow binding block frame. + * \return The created ir_builder Block frame. + */ +TVM_DLL BlockFrame Dataflow(); + +/*! + * \brief Expose the dataflow block output variables as global ones + * \param vars The output variables of a dataflow block + */ +TVM_DLL void DataflowBlockOutput(const Array& vars); + +////////////////////////////// Bindings //////////////////////////////// + +/*! + * \brief Emit a binding to the last binding block frame. + * \param value The right side value of the bindings to be emitted. + * \param is_dataflow_var A boolean indicating if the emitted binding variable is a dataflow + * variable. + * \return The left side var of the emitted binding. + */ +TVM_DLL tvm::relax::Var Emit(const tvm::relax::Expr& value, bool is_dataflow_var); + +/*! + * \brief Emit a match_shape binding to the last binding block frame. + * \param value The value of the MatchShape to be emitted. + * \param pattern The pattern of the MatchShape to be emitted. + * \param emit_var A boolean indicating if the MatchShape contains the emitted variable. + * \param is_dataflow_var A boolean indicating if the emitted variable is a dataflow variable when + * `emit_var` is true. When `emit_var` is false, the value of this flag will be ignored. + * \return The emitted var if `emit_var` is true. Otherwise, return `NullOpt`. + */ +TVM_DLL Optional EmitMatchShape(const tvm::relax::Expr& value, // + const Array& pattern, // + bool emit_var, // + bool is_dataflow_var); + +///////////////////////////// Type Deduce ////////////////////////////// + +/*! + * \brief Annotate and check the type and shape of relax var. + * \param var The input var to be annotated. + * \param type The given type. + * \param shape The given shape, which can be undefined. + * \note This function will check if the type of var is compatible with the given type. + * And we annotate to the var with more detailed type. + */ +TVM_DLL void AnnotateTypeShape(const tvm::relax::Var& var, const Type& type, + const Optional& shape); + +///////////////////////////// If Then Else ///////////////////////////// + +/*! + * \brief Create an if statement. + * \param condition The condition of if statement. + * \return The result IfFrame. + */ +IfFrame If(tvm::relax::Expr condition); +/*! + * \brief Create a then. + * \return The result ThenFrame. + */ +ThenFrame Then(); +/*! + * \brief Create an else. + * \return The result ElseFrame. + */ +ElseFrame Else(); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_IR_H_ diff --git a/python/tvm/meta_schedule/__init__.py b/python/tvm/meta_schedule/__init__.py index 30a4fc6d94..87e44fdbf8 100644 --- a/python/tvm/meta_schedule/__init__.py +++ b/python/tvm/meta_schedule/__init__.py @@ -24,6 +24,7 @@ measure_callback, mutator, postproc, + relax_integration, relay_integration, runner, schedule, diff --git a/python/tvm/meta_schedule/relax_integration.py b/python/tvm/meta_schedule/relax_integration.py index 41bacd2a09..587dde7dfb 100644 --- a/python/tvm/meta_schedule/relax_integration.py +++ b/python/tvm/meta_schedule/relax_integration.py @@ -15,18 +15,44 @@ # specific language governing permissions and limitations # under the License. """Meta schedule integration with high-level IR""" -from typing import List, Dict, Optional +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +# isort: off +from typing_extensions import Literal + +# isort: on from tvm._ffi import get_global_func from tvm.ir import IRModule -from tvm.meta_schedule import ExtractedTask +from tvm.ir.transform import PassContext +from tvm.runtime import NDArray from tvm.target import Target -from tvm.runtime import NDArray +from .builder import Builder +from .cost_model import CostModel +from .database import Database +from .extracted_task import ExtractedTask +from .logging import get_loggers_from_work_dir +from .measure_callback import MeasureCallback +from .runner import Runner +from .search_strategy import SearchStrategy +from .space_generator import SpaceGenerator +from .task_scheduler import TaskScheduler +from .tune import tune_tasks +from .tune_context import TuneContext +from .utils import fork_seed +if TYPE_CHECKING: + from tvm import relax -def extract_task_from_relax( - mod: IRModule, +_extract_task_func = get_global_func( # pylint: disable=invalid-name + "relax.backend.MetaScheduleExtractTask", + allow_missing=False, +) + + +def extract_tasks( + mod: Union[IRModule, "relax.Function"], target: Target, params: Optional[Dict[str, NDArray]] = None, ) -> List[ExtractedTask]: @@ -34,7 +60,7 @@ def extract_task_from_relax( Parameters ---------- - mod : tvm.IRModule + mod : Union[IRModule, relax.Function] The module or function to tune target : tvm.target.Target The compilation target @@ -48,20 +74,190 @@ def extract_task_from_relax( from tvm.relax.expr import Function as RelaxFunc from tvm.relax.transform import BindParams - # todo(@yongwww): fix circular import error, - # update type hint of mod to Union[IRModule, RelaxFunc] - extract_task_func = get_global_func( - "relax.backend.MetaScheduleExtractTask", - allow_missing=False, + # pylint: enable=import-outside-toplevel + if isinstance(mod, RelaxFunc): + mod = IRModule({"main": mod}) + if not isinstance(target, Target): + target = Target(target) + if params: + mod = BindParams("main", params)(mod) + return list(_extract_task_func(mod, target)) + + +def extracted_tasks_to_tune_contexts( + extracted_tasks: List[ExtractedTask], + work_dir: str, + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + num_threads: Union[Literal["physical", "logical"], int] = "physical", + seed: Optional[int] = None, +) -> Tuple[List[TuneContext], List[float]]: + """Convert ExtractedTask to TuneContext. + + Parameters + ---------- + tasks : List[ExtractedTask] + The tasks to be converted + work_dir : str + The working directory to store logs and databases + space : SpaceGenerator.SpaceGeneratorType + The space generator to use. + strategy : SearchStrategy.SearchStrategyType + The search strategy to use. + num_threads : Union[Literal["physical", "logical"], int] + The number of threads to use in multi-threaded search algorithm. + seed : Optional[int] + The random seed to use. + + Returns + ------- + tasks : List[TuneContext] + The converted tasks + task_weights : List[float] + The weights of the tasks + """ + tasks: List[TuneContext] = [] + task_weights: List[float] = [] + for task, logger, rand_state in zip( + extracted_tasks, + get_loggers_from_work_dir(work_dir, [t.task_name for t in extracted_tasks]), + fork_seed(seed, n=len(extracted_tasks)), + ): + tasks.append( + TuneContext( + mod=task.dispatched[0], + target=task.target, + space_generator=space, + search_strategy=strategy, + task_name=task.task_name, + logger=logger, + rand_state=rand_state, + num_threads=num_threads, + ).clone() + ) + task_weights.append(task.weight) + return tasks, task_weights + + +def tune_relax( + mod: Union[IRModule, "relax.Function"], + params: Dict[str, NDArray], + target: Union[str, Target], + work_dir: str, + max_trials_global: int, + *, + max_trials_per_task: Optional[int] = None, + num_trials_per_iter: int = 64, + builder: Builder.BuilderType = "local", + runner: Runner.RunnerType = "local", + database: Database.DatabaseType = "json", + cost_model: CostModel.CostModelType = "xgb", + measure_callbacks: MeasureCallback.CallbackListType = "default", + task_scheduler: TaskScheduler.TaskSchedulerType = "gradient", + space: SpaceGenerator.SpaceGeneratorType = "post-order-apply", + strategy: SearchStrategy.SearchStrategyType = "evolutionary", + seed: Optional[int] = None, +) -> Database: + """Tune a Relax program. + + Parameters + ---------- + mod : Union[IRModule, relax.Function] + The module or function to tune + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + target : Union[Target, str] + The compilation target + work_dir : str + The working directory to store the tuning records + max_trials_global : int + The maximum number of trials to run + max_trials_per_task : Optional[int] + The maximum number of trials to run for each task + num_trials_per_iter : int + The number of trials to run per iteration + builder : BuilderType + The builder to use + runner : RunnerType + The runner to use + database : DatabaseType + The database to use + cost_model : CostModelType + The cost model to use + measure_callbacks : CallbackListType + The measure callbacks to use + task_scheduler : TaskSchedulerType + The task scheduler to use + space : SpaceGeneratorType + The space generator to use + strategy : SearchStrategyType + The search strategy to use + seed : Optional[int] + The random seed + + Returns + ------- + database : Database + The database that contains the tuning records + """ + tasks, task_weights = extracted_tasks_to_tune_contexts( + extracted_tasks=extract_tasks(mod, target, params), + work_dir=work_dir, + space=space, + strategy=strategy, + seed=seed, + ) + return tune_tasks( + tasks=tasks, + task_weights=task_weights, + work_dir=work_dir, + max_trials_global=max_trials_global, + max_trials_per_task=max_trials_per_task, + num_trials_per_iter=num_trials_per_iter, + builder=builder, + runner=runner, + database=database, + cost_model=cost_model, + measure_callbacks=measure_callbacks, + task_scheduler=task_scheduler, ) - if isinstance(mod, RelaxFunc): - mod = IRModule.from_expr(mod) +def compile_relax( + database: Database, + mod: IRModule, + target: Union[Target, str], + params: Optional[Dict[str, NDArray]], +) -> "relax.vm.Executable": + """Compile a relax program with a MetaSchedule database. + + Parameters + ---------- + database : Database + The database to use + mod : IRModule + The Relax program to be compiled + target : tvm.target.Target + The compilation target + params : Optional[Dict[str, tvm.runtime.NDArray]] + The associated parameters of the program + + Returns + ------- + lib : relax.vm.Executable + The built runtime module or vm Executable for the given relax workload. + """ + # pylint: disable=import-outside-toplevel + from tvm.relax.transform import BindParams, MetaScheduleApplyDatabase + from tvm.relax.vm import build as relax_build + + # pylint: enable=import-outside-toplevel if not isinstance(target, Target): target = Target(target) - if params: mod = BindParams("main", params)(mod) - return list(extract_task_func(mod, target)) + with target, database, PassContext(opt_level=3): + relax_mod = MetaScheduleApplyDatabase()(mod) + relax_ex = relax_build(relax_mod, target=target) + return relax_ex diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 8a4fe299eb..18b4921289 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -17,15 +17,17 @@ # pylint: disable=invalid-name, unused-import, super-init-not-called # pylint: disable=redefined-builtin """The expression nodes of Relax.""" -from typing import List, Optional, Union -import tvm._ffi +from typing import Any, List, Optional, Union + import tvm -from ..ir import Node, Span, SourceName, BaseFunc -from ..runtime import String +import tvm._ffi + +from .. import relay +from ..ir import BaseFunc, Node, SourceName, Span from ..relay import Id, Tuple, TupleGetItem +from ..runtime import String from ..tir import PrimExpr -from . import _ffi_api -from .. import relay +from . import _ffi_api, ty Expr = relay.Expr Type = relay.Type @@ -98,6 +100,12 @@ def name_hint(self): name = str(self.vid.name_hint) return name + def __call__(self, *args: Any, attrs=None) -> Call: + if self.checked_type and isinstance(self.checked_type, ty.FuncType): + return Call(self, args, attrs=attrs) + else: + raise TypeError("Only vars with function type can be called") + @tvm._ffi.register_object("relax.expr.DataflowVar") class DataflowVar(Var): diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index ca70193968..7428ea590b 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -21,3 +21,4 @@ from .base import * from .tensor import * from .op_attrs import * +from . import builtin diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index f921d367da..c367d630de 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -15,15 +15,15 @@ # specific language governing permissions and limitations # pylint: disable=redefined-builtin """The base Relax operators.""" -from typing import List, Optional, Union +from typing import Union, List, Optional import tvm from tvm.runtime.object import Object -from ...ir import Array -from ..expr import Call, Expr, ExternFunc, ShapeExpr, Tuple -from ..ty import DynTensorType, TupleType from . import _ffi_api +from ..expr import Expr, ShapeExpr, Tuple, Call, ExternFunc +from ..ty import DynTensorType, TupleType +from ...ir import Array py_print = print # pylint: disable=invalid-name diff --git a/python/tvm/relax/op/builtin/__init__.py b/python/tvm/relax/op/builtin/__init__.py new file mode 100644 index 0000000000..04837724b1 --- /dev/null +++ b/python/tvm/relax/op/builtin/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=wildcard-import, redefined-builtin +"""Relax builtin operators.""" + +from .builtin import * diff --git a/python/tvm/relax/op/builtin/_ffi_api.py b/python/tvm/relax/op/builtin/_ffi_api.py new file mode 100644 index 0000000000..42fe8cb652 --- /dev/null +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""FFI APIs for tvm.relax.op.builtin""" +import tvm._ffi + +tvm._ffi._init_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py new file mode 100644 index 0000000000..0c80ba73d6 --- /dev/null +++ b/python/tvm/relax/op/builtin/builtin.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""The builtin Relax operators.""" + +from typing import List, Union +from tvm.ir.expr import PrimExpr +from . import _ffi_api +from ...expr import ShapeExpr, Call + + +# TODO(relax-team): add documents +def alloc_tensor( + shape: Union[ShapeExpr, PrimExpr, List[PrimExpr]], dtype: str, runtime_device_index: int +) -> Call: + if not isinstance(shape, ShapeExpr): + if not isinstance(shape, (tuple, list)): + shape = (shape,) + shape = ShapeExpr(shape) + return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) diff --git a/python/tvm/relax/op/tensor.py b/python/tvm/relax/op/tensor.py index 6bc90cec03..3d48973f01 100644 --- a/python/tvm/relax/op/tensor.py +++ b/python/tvm/relax/op/tensor.py @@ -13,6 +13,7 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations +# pylint: disable=redefined-builtin """Basic tensor operations.""" import numpy as np import tvm @@ -29,8 +30,45 @@ def multiply(lhs: Expr, rhs: Expr) -> Expr: return _ffi_api.multiply(lhs, rhs) -@tvm.register_func("relax.run.unique") def unique( + data: Expr, + sorted: bool = True, + return_inverse: bool = False, + return_counts: bool = False, + dim: int = -1, +) -> Expr: + """Find the unique elements and the new index of each item in a given tensor. + + Parameters + ---------- + data : Expr + The input tensor. + + sorted: bool + Whether to sort the unique elements in ascending order before + returning as output. + + return_inverse: bool + Whether to return an additional tensor with indices for where elements in + the original input ended up in the returned unique list. + + return_counts: bool + Whether to return an additional tensor with counts of each unique elements. + + dim: int + The dimension to apply unique. If negative, the unique of the flattened input is returned. + + Returns + ------- + ret: Expr + The created relax call with + """ + + return _ffi_api.unique(data, sorted, return_inverse, return_counts, dim) + + +@tvm.register_func("relax.run.unique") +def numpy_unique( a: tvm.nd.array, sort: int, return_inverse: int, diff --git a/python/tvm/relax/testing/relay_translator.py b/python/tvm/relax/testing/relay_translator.py index 9f78e1a82b..77968de0e2 100644 --- a/python/tvm/relax/testing/relay_translator.py +++ b/python/tvm/relax/testing/relay_translator.py @@ -18,15 +18,16 @@ """Relay to Relax translator.""" from __future__ import annotations + from typing import Any, Dict, List, Optional + import tvm -from tvm.ir.module import IRModule from tvm import relax, relay +from tvm.ir.module import IRModule from tvm.relax.testing import nn from tvm.relay.backend.te_compiler import select_implementation from tvm.runtime import NDArray from tvm.target import Target -from tvm.meta_schedule.utils import autotvm_silencer def from_relay( @@ -83,6 +84,7 @@ def from_relay( pass_config = { "relay.FuseOps.max_depth": 1, # Disable relay fusion "relay.backend.use_meta_schedule": True, + "relay.backend.use_meta_schedule_dispatch": True, } if relay_params: @@ -132,19 +134,24 @@ def visit_func(node): call = relax.call_tir(tir_gvar, new_args, out_type.shape, out_type.dtype) var = bb.emit(call) else: - best_impl, outputs = select_implementation( - node.op, - attrs, - te_inputs, - out_type, - target, - use_autotvm=False, - ) - compute_func = best_impl.compute - name_hint = op_name.split(".")[-1] - var = bb.emit_te( - compute_func, attrs, new_args, node.checked_type, primfunc_name_hint=name_hint - ) + with target: + best_impl, outputs = select_implementation( + node.op, + attrs, + te_inputs, + out_type, + target, + use_autotvm=False, + ) + compute_func = best_impl.compute + name_hint = op_name.split(".")[-1] + var = bb.emit_te( + compute_func, + attrs, + new_args, + node.checked_type, + primfunc_name_hint=name_hint, + ) output_var = var var_map[node] = var @@ -189,8 +196,10 @@ def visit_func(node): # Since optimization passes and OpStrategy are highly context-dependent, # we match the exact same context with `extract_task_from_relay()` env - with autotvm_silencer(), target, tvm.transform.PassContext( - opt_level=opt_level, config=pass_config, disabled_pass=disabled_pass + with tvm.transform.PassContext( + opt_level=opt_level, + config=pass_config, + disabled_pass=disabled_pass, ): mod = tvm.IRModule.from_expr(func) mod = seq(mod) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index 12e00dc2e9..47901357ba 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -19,13 +19,10 @@ import functools import inspect import types -from typing import Callable, Dict, Union, Optional, List -import numpy as np +from typing import Callable, Dict, List, Optional, Union +import numpy as np import tvm.ir -from tvm.target import Target -from tvm.meta_schedule.tune import TuneConfig -from tvm.meta_schedule.database import PyDatabase from . import _ffi_api @@ -162,73 +159,6 @@ def ResolveGlobals() -> tvm.ir.transform.Pass: return _ffi_api.ResolveGlobals() -def MetaScheduleTuneTIR( - target: Union[str, Target], config: TuneConfig, work_dir: str -) -> tvm.ir.transform.Pass: - """Tune TIR with MetaSchedule. - - Parameters - ---------- - target: Union[str, Target] - target info - config: TuneConfig - MetaSchedule tuning info - work_dir: str - work directory - Returns - ------- - ret: tvm.ir.transform.Pass - - """ - if isinstance(target, str): - target = tvm.target.Target(target) - return _ffi_api.MetaScheduleTuneTIR(target, config, work_dir) - - -def MetaScheduleTuneIRMod( - target: Union[str, Target], config: TuneConfig, work_dir: str -) -> tvm.ir.transform.Pass: - """Tune Relax IRModule with MetaSchedule. - - Parameters - ---------- - target: Union[str, Target] - target info - config: TuneConfig - MetaSchedule tuning info - work_dir: str - work directory - Returns - ------- - ret: tvm.ir.transform.Pass - - """ - if isinstance(target, str): - target = tvm.target.Target(target) - return _ffi_api.MetaScheduleTuneIRMod(target, config, work_dir) - - -def MetaScheduleApplyHistoryBest( - database: PyDatabase, - target: Target, -) -> tvm.ir.transform.Pass: - """Apply the best schedule from tuning database. - - Parameters - ---------- - database : PyDatabase - metaschedule tuning database - target: Target - target info - - Returns - ------- - ret: tvm.ir.transform.Pass - - """ - return _ffi_api.MetaScheduleApplyHistoryBest(database, target) - - def BindParams( func_name: str, params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]], @@ -353,6 +283,17 @@ def FuseTIR() -> tvm.ir.transform.Pass: return _ffi_api.FuseTIR() +def MetaScheduleApplyDatabase() -> tvm.ir.transform.Pass: + """Apply the best schedule from tuning database. + + Returns + ------- + ret : tvm.transform.Pass + The registered pass for tir fusion. + """ + return _ffi_api.MetaScheduleApplyDatabase() + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/script/_parser/core/dispatch.py b/python/tvm/script/_parser/core/dispatch.py new file mode 100644 index 0000000000..f10b90961a --- /dev/null +++ b/python/tvm/script/_parser/core/dispatch.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type + +from .doc import AST + +if TYPE_CHECKING: + from .parser import Parser + + +ParseMethod = Callable[["Parser", AST], None] +ParseVTable: Dict[Tuple[str, str], ParseMethod] = {} + +OpMethod = Callable[..., Any] +OpVTable: Dict[Tuple[Type, AST, int], OpMethod] = {} + + +def register(token: str, type_name: str): + """Register a method for a dispatch token and type name""" + + def f(method: ParseMethod): + ParseVTable[(token, type_name)] = method + + return f + + +def get( + token: str, + type_name: str, + default: Optional[ParseMethod] = None, +) -> Optional[ParseMethod]: + return ParseVTable.get((token, type_name), default) + + +def register_op(ty: Type, op: AST, operand_index: int): # pylint: disable=invalid-name + def f(method: OpMethod): + OpVTable[(ty, op, operand_index)] = method + + return f + + +def get_op( # pylint: disable=invalid-name + ty: Type, + op: Type, + operand_index: int, + default: Optional[OpMethod] = None, +) -> Optional[OpMethod]: + return OpVTable.get((ty, op, operand_index), default) diff --git a/python/tvm/script/_parser/core/entry.py b/python/tvm/script/_parser/core/entry.py new file mode 100644 index 0000000000..afd3cb5027 --- /dev/null +++ b/python/tvm/script/_parser/core/entry.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""The entry point of TVM parser.""" +from typing import Any, Union + +from ...ir_builder import IRBuilder +from . import doc +from .diagnostics import Source +from .parser import Parser + + +def parse(program: Union[doc.AST, Any, str], extra_vars=None): + if extra_vars is None: + from tvm.script._parser import ir # pylint: disable=import-outside-toplevel + from tvm.script._parser import relax # pylint: disable=import-outside-toplevel + from tvm.script._parser import tir # pylint: disable=import-outside-toplevel + + extra_vars = { + "I": ir, + "ir": ir, + "T": tir, + "tir": tir, + "relax": relax, + "R": relax, + } + + source = Source(program) + parser = Parser(source) + with IRBuilder() as builder: + parser.parse(extra_vars=extra_vars) + return builder.get() diff --git a/python/tvm/script/_parser/core/evaluator.py b/python/tvm/script/_parser/core/evaluator.py new file mode 100644 index 0000000000..0c2ccee48a --- /dev/null +++ b/python/tvm/script/_parser/core/evaluator.py @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""AST Evaluation""" +import ast +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union + +from . import dispatch, doc + +if TYPE_CHECKING: + from .parser import Parser + +DEFAULT_OP: Dict[Type, Callable[..., Any]] = { + doc.Add: lambda a, b: a + b, + doc.Sub: lambda a, b: a - b, + doc.Mult: lambda a, b: a * b, + doc.Div: lambda a, b: a / b, + doc.FloorDiv: lambda a, b: a // b, + doc.Mod: lambda a, b: a % b, + doc.LShift: lambda a, b: a << b, + doc.RShift: lambda a, b: a >> b, + doc.BitOr: lambda a, b: a | b, + doc.BitXor: lambda a, b: a ^ b, + doc.BitAnd: lambda a, b: a & b, + doc.MatMult: lambda a, b: a @ b, + # fmt: off + doc.Pow: lambda a, b: a**b, + # fmt: on + doc.Eq: lambda a, b: a == b, + doc.NotEq: lambda a, b: a != b, + doc.Lt: lambda a, b: a < b, + doc.LtE: lambda a, b: a <= b, + doc.Gt: lambda a, b: a > b, + doc.GtE: lambda a, b: a >= b, + doc.Is: lambda a, b: a is b, + doc.IsNot: lambda a, b: a is not b, + doc.In: lambda a, b: a in b, + doc.NotIn: lambda a, b: a not in b, + doc.And: lambda a, b: a and b, + doc.Or: lambda a, b: a or b, + doc.Invert: lambda a: ~a, + doc.Not: lambda a: not a, + doc.UAdd: lambda a: +a, + doc.USub: lambda a: -a, +} + + +class ExprEvaluator: + + parser: "Parser" + value_table: Dict[str, Any] + new_value_count: int + + def __init__(self, parser: "Parser", value_table: Dict[str, Any]) -> None: + super().__init__() + self.parser = parser + self.value_table = value_table + self.new_value_count = 0 + + @staticmethod + def eval(parser: "Parser", value_table: Dict[str, Any], node: doc.AST) -> Any: + self = ExprEvaluator(parser, value_table) + result = self._visit(node) # pylint: disable=protected-access + if isinstance(result, doc.Name): + if result.id not in self.value_table: + self.parser.report_error(result, f"Undefined variable: {result.id}") + return self.value_table[result.id] + if isinstance(result, doc.Constant): + return result.value + raise TypeError(f"Unexpected result type: {type(result)}") + + def _add_intermediate_result(self, value: Any) -> doc.Name: + name = f"__tvm_tmp_value_{self.new_value_count}" + self.new_value_count += 1 + self.value_table[name] = value + lineno = 0 + col_offset = 0 + return doc.Name( + id=name, + ctx=doc.Load( + lineno=lineno, + col_offset=col_offset, + end_lineno=None, + end_col_offset=None, + ), + lineno=lineno, + col_offset=col_offset, + end_lineno=None, + end_col_offset=None, + ) + + def _visit(self, node: doc.AST) -> Any: + if isinstance(node, list): + return [self._visit(n) for n in node] + if isinstance(node, tuple): + return tuple(self._visit(n) for n in node) + assert isinstance(node, doc.AST) + if isinstance(node, doc.Name): + if node.id not in self.value_table: + self.parser.report_error(node, f"Undefined variable: {node.id}") + return node + if isinstance( + node, + ( + doc.Constant, + doc.expr_context, + doc.operator, + doc.boolop, + doc.unaryop, + doc.cmpop, + ), + ): + return node + if not isinstance(node, (doc.expr, doc.slice)): + return node + if isinstance(node, doc.Lambda): + return self._eval_lambda(node) + fields = {} + for field in node.__class__._FIELDS: # pylint: disable=protected-access + attr = getattr(node, field) + if isinstance(attr, (doc.AST, tuple, list)): + fields[field] = self._visit(attr) + else: + fields[field] = attr + try: + if isinstance(node, doc.BoolOp): + value = self._eval_bool_op(fields) + elif isinstance(node, doc.Compare): + value = self._eval_compare(fields) + elif isinstance(node, doc.UnaryOp): + value = self._eval_unary_op(fields) + elif isinstance(node, doc.BinOp): + value = self._eval_bin_op(fields) + elif isinstance(node, doc.Slice): + value = self._eval_slice(fields) + else: + value = self._eval_expr(node.__class__(**fields)) + except Exception as e: # pylint: disable=broad-except,invalid-name + self.parser.report_error(node, str(e)) + return self._add_intermediate_result(value) + + def _eval_lambda(self, node: doc.Lambda) -> Any: + try: + value = self._eval_expr(node) + except Exception as e: # pylint: disable=broad-except,invalid-name + self.parser.report_error(node, str(e)) + return self._add_intermediate_result(value) + + def _eval_bool_op(self, fields: Dict[str, Any]) -> Any: + op = fields["op"] + if not isinstance(op, (doc.And, doc.Or)): + raise TypeError(f"Unexpected operator: {op}") + value = self._eval_expr(fields["values"][0]) + for rhs in fields["values"][1:]: + value = _eval_op(op, values=[value, self._eval_expr(rhs)]) + return value + + def _eval_compare(self, fields: Dict[str, Any]) -> Any: + value = self._eval_expr(fields["left"]) + for op, rhs in zip(fields["ops"], fields["comparators"]): + value = _eval_op(op, values=[value, self._eval_expr(rhs)]) + return value + + def _eval_unary_op(self, fields: Dict[str, Any]) -> Any: + value = self._eval_expr(fields["operand"]) + value = _eval_op(fields["op"], values=[value]) + return value + + def _eval_bin_op(self, fields: Dict[str, Any]) -> Any: + return _eval_op( + fields["op"], + values=[ + self._eval_expr(fields["left"]), + self._eval_expr(fields["right"]), + ], + ) + + def _eval_slice(self, fields: Dict[str, Any]) -> Any: + lower, upper, step = fields["lower"], fields["upper"], fields["step"] + + lower = self._eval_expr(lower) if lower is not None else None + upper = self._eval_expr(upper) if upper is not None else None + step = self._eval_expr(step) if step is not None else None + + return slice(lower, upper, step) + + def _eval_expr(self, v: Any) -> Any: + return _eval_expr(v, self.value_table) + + +def eval_expr( + parser: "Parser", + node: Union[doc.expr, doc.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + value_table = {} + if dict_globals is not None: + value_table.update(dict_globals) + return ExprEvaluator.eval(parser, value_table, node) + + +def eval_assign( + parser: "Parser", + target: doc.expr, + source: Any, +) -> Dict[str, Any]: + try: + return _eval_assign(target, source) + except Exception as e: # pylint: disable=broad-except,invalid-name + parser.report_error(target, f"Failed to evaluate assignment: {str(e)}") + raise + + +def _eval_expr( + node: Union[doc.expr, doc.Expression], + dict_globals: Optional[Dict[str, Any]], +) -> Any: + node = doc.from_doc(node) + if isinstance(node, ast.expr): + node = ast.Expression(body=node) + assert isinstance(node, ast.Expression), "Expects an ast.Expression, but gets: " + str(node) + if dict_globals is None: + dict_globals = {} + node = ast.fix_missing_locations(node) + exe = compile(node, filename="", mode="eval") + return eval(exe, dict_globals) # pylint: disable=eval-used + + +def _eval_op( + op: doc.AST, + values: List[Any], +): + op_type = type(op) # pylint: disable=protected-access + for i, v in enumerate(values): + v_type = getattr(type(v), "_dispatch_type", None) + if v_type is None: + continue + f = dispatch.get_op(ty=v_type, op=op_type, operand_index=i, default=None) + if f is not None: + return f(*values) + return DEFAULT_OP[op_type](*values) + + +def _eval_assign( + target: doc.expr, + source: Any, +) -> Dict[str, Any]: + target = doc.from_doc(target) + assert isinstance(target, ast.expr) + RHS_VAR_NAME = "__tvm_rhs_var__" # pylint: disable=invalid-name + rhs_var_name = RHS_VAR_NAME + dict_locals = {rhs_var_name: source} + mod = ast.fix_missing_locations( + ast.Module( + body=[ + ast.Assign( + targets=[target], + value=ast.Name( + id=rhs_var_name, + ctx=ast.Load(), + ), + ) + ], + type_ignores=[], + ) + ) + exe = compile(mod, filename="", mode="exec") + exec(exe, {}, dict_locals) # pylint: disable=exec-used + del dict_locals[rhs_var_name] + return dict_locals diff --git a/python/tvm/script/_parser/core/parser.py b/python/tvm/script/_parser/core/parser.py new file mode 100644 index 0000000000..7846bd8c0f --- /dev/null +++ b/python/tvm/script/_parser/core/parser.py @@ -0,0 +1,300 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +"""The core parser""" +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Optional, Set, Union + +from tvm._ffi.base import TVMError +from tvm.error import DiagnosticError + +from . import dispatch, doc +from .diagnostics import Diagnostics, Source +from .evaluator import eval_assign, eval_expr + +DEFAULT_VISIT = { + "Interactive", + "Module", + "Expression", + "Pass", +} + + +def _deferred(f: Callable[[], None]): + @contextmanager + def context(): + try: + yield + finally: + f() + + return context() + + +class VarTableFrame: + vars: Set[str] + + def __init__(self): + self.vars = set() + + def add(self, var: str): + if var in self.vars: + raise ValueError(f"Variable {var} already defined in current scope") + self.vars.add(var) + + def pop_all(self, fn_pop: Callable[[str], None]): + for var in self.vars: + fn_pop(var) + self.vars.clear() + + +class VarTable: + + frames: List[VarTableFrame] + name2value: Dict[str, List[Any]] + + def __init__(self): + self.frames = [] + self.name2value = defaultdict(list) + + def with_frame(self): + def pop_frame(): + frame = self.frames.pop() + frame.pop_all(lambda name: self.name2value[name].pop()) + + self.frames.append(VarTableFrame()) + return _deferred(pop_frame) + + def add(self, var: str, value: Any, allow_shadowing: bool = False): + # Skip if the key and value are equal to those in the var_table + if self.name2value[var] and self.name2value[var][-1] == value: + return + if allow_shadowing and var in self.frames[-1].vars: + # Shadowing + self.name2value[var][-1] = value + else: + self.frames[-1].add(var) + self.name2value[var].append(value) + + def get(self) -> Dict[str, Any]: + return {key: values[-1] for key, values in self.name2value.items() if values} + + def exist(self, value: Any): + for v in self.name2value.values(): + if v is value: + return True + return False + + +def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod: + def _wrapper(self: "Parser", node: doc.AST) -> None: + try: + return func(self, node) + except DiagnosticError: + raise + except Exception as e: # pylint: disable=broad-except,invalid-name + self.report_error(node, e) + raise + + return _wrapper + + +def _dispatch(self: "Parser", type_name: str) -> dispatch.ParseMethod: + for token in [self.dispatch_tokens[-1], "default"]: + func = dispatch.get(token=token, type_name=type_name, default=None) + if func is not None: + return _dispatch_wrapper(func) + return _dispatch_wrapper(lambda self, node: self.generic_visit(node)) + + +def _dispatch_optional(self: "Parser", type_name: str) -> Optional[dispatch.ParseMethod]: + for token in [self.dispatch_tokens[-1], "default"]: + func = dispatch.get(token=token, type_name=type_name, default=None) + if func is not None: + return _dispatch_wrapper(func) + return None + + +class Parser(doc.NodeVisitor): + """The TVMScript parser""" + + diag: Diagnostics + dispatch_tokens: List[str] + var_table: VarTable + + def __init__(self, source: Source) -> None: + self.diag = Diagnostics(source) + self.dispatch_tokens = ["default"] + self.var_table = VarTable() + + def parse(self, extra_vars: Optional[Dict[str, Any]] = None) -> Any: + if extra_vars is None: + extra_vars = {} + with self.var_table.with_frame(): + for k, v in extra_vars.items(): + self.var_table.add(k, v) + node = self.diag.source.as_ast() + self.visit(node) + + def with_dispatch_token(self, token: str): + def pop_token(): + self.dispatch_tokens.pop() + + self.dispatch_tokens.append(token) + return _deferred(pop_token) + + def eval_expr( + self, + node: Union[doc.Expression, doc.expr], + extra_vars: Optional[Dict[str, Any]] = None, + ) -> Any: + var_values = self.var_table.get() + if extra_vars is not None: + for k, v in extra_vars.items(): + var_values[k] = v + return eval_expr(self, node, var_values) + + def _duplicate_lhs_check(self, target: doc.expr) -> Union[bool, Set[str]]: + if isinstance(target, (doc.Tuple, doc.List)): + vars: Set[str] = set() # pylint: disable=redefined-builtin + for i in target.elts: + res = self._duplicate_lhs_check(i) + if isinstance(res, bool) and res: + return True + assert isinstance(res, set) + if vars & res: + return True + vars = vars.union(res) + return vars + elif isinstance(target, doc.Name): + return {target.id} + else: + self.report_error(target, "Invalid type in assign statement") + raise NotImplementedError + + def eval_assign( + self, + target: doc.expr, + source: Any, + bind_value: Callable[["Parser", doc.expr, str, Any], Any], + allow_shadowing: bool = False, + ) -> Dict[str, Any]: + if self._duplicate_lhs_check(target) is True: + self.report_error(target, "Duplicate vars assigned.") + var_values = eval_assign(self, target, source) + for k, v in var_values.items(): + var = bind_value(self, target, k, v) + self.var_table.add(k, var, allow_shadowing) + return var_values + + def report_error( + self, node: doc.AST, err: Union[Exception, str] + ) -> None: # pylint: disable=no-self-use + # Only take the last line of the error message + if isinstance(err, (TVMError, ValueError, TypeError)): + msg = list(filter(None, str(err).split("\n")))[-1] + else: + msg = str(err) + self.diag.error(node, msg) + + def visit(self, node: doc.AST) -> None: + if isinstance(node, (list, tuple)): + for item in node: + self.visit(item) + return + if not isinstance(node, doc.AST): + return + name = node.__class__.__name__.split(".")[-1] + if name in DEFAULT_VISIT: + func = self.generic_visit + else: + func = getattr(self, "visit_" + name, None) + if func is None: + raise NotImplementedError(f"Visitor of AST node is not implemented: {name}") + try: + func(node) + except DiagnosticError: + raise + except Exception as e: # pylint: disable=broad-except,invalid-name + self.report_error(node, str(e)) + raise + + def visit_body(self, node: List[doc.stmt]) -> Any: + for stmt in node: + self.visit(stmt) + + def visit_tvm_annotation(self, node: doc.expr) -> Any: + return _dispatch(self, "tvm_annotation")(self, node) + + def visit_FunctionDef(self, node: doc.FunctionDef) -> Any: # pylint: disable=invalid-name + if not node.decorator_list: + self.report_error(node, "Function must be decorated") + # TODO: only the last decorator is parsed + decorator = self.eval_expr(node.decorator_list[-1]) + if not hasattr(decorator, "dispatch_token"): + self.report_error(node, "The parser does not understand the decorator") + token = decorator.dispatch_token + func = dispatch.get(token=token, type_name="FunctionDef", default=None) + if func is None: + self.report_error(node, "The parser does not understand the decorator") + pre_func = _dispatch_optional(self, "pre_token_switch") + post_func = _dispatch_optional(self, "post_token_switch") + if pre_func: + pre_func(self, node) + _dispatch_wrapper(func)(self, node) + if post_func: + post_func(self, node) + + def visit_ClassDef(self, node: doc.ClassDef) -> Any: # pylint: disable=invalid-name + func = dispatch.get(token="ir", type_name="ClassDef", default=None) + if func is None: + self.report_error(node, "The parser does not understand the decorator") + _dispatch_wrapper(func)(self, node) + + def visit_arguments(self, node: doc.arguments) -> Any: + return _dispatch(self, "arguments")(self, node) + + def visit_For(self, node: doc.For) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "For")(self, node) + + def visit_While(self, node: doc.While) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "While")(self, node) + + def visit_With(self, node: doc.With) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "With")(self, node) + + def visit_Assign(self, node: doc.Assign) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Assign")(self, node) + + def visit_Expr(self, node: doc.Expr) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Expr")(self, node) + + def visit_If(self, node: doc.If) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "If")(self, node) + + def visit_AnnAssign(self, node: doc.AnnAssign) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "AnnAssign")(self, node) + + def visit_AugAssign(self, node: doc.AugAssign) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "AugAssign")(self, node) + + def visit_Assert(self, node: doc.Assert) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Assert")(self, node) + + def visit_Return(self, node: doc.Return) -> Any: # pylint: disable=invalid-name + return _dispatch(self, "Return")(self, node) diff --git a/python/tvm/script/_parser/ir/__init__.py b/python/tvm/script/_parser/ir/__init__.py new file mode 100644 index 0000000000..8cf9b50665 --- /dev/null +++ b/python/tvm/script/_parser/ir/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from . import parser as _parser +from .entry import ir_module, is_defined_in_class + +__all__ = ["ir_module", "is_defined_in_class"] diff --git a/python/tvm/script/_parser/ir/entry.py b/python/tvm/script/_parser/ir/entry.py new file mode 100644 index 0000000000..e0a0213cd1 --- /dev/null +++ b/python/tvm/script/_parser/ir/entry.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import inspect +from typing import Type + +from tvm.ir import IRModule + +from .._core import parse, utils + + +def is_defined_in_class(frames): + if len(frames) > 2: + maybe_class_frame = frames[2] + statement_list = maybe_class_frame[4] + if statement_list is None: + return False + first_statement = statement_list[0] + line = first_statement.strip() + if line.startswith("class "): + return True + if line.startswith("@") and "ir_module" in line: + return True + return False + + +def ir_module(f: Type) -> IRModule: + if not inspect.isclass(f): + raise TypeError(f"Expect a class, but got: {f}") + + return parse(f, utils.inspect_class_capture(f)) + + +setattr(ir_module, "dispatch_token", "ir") diff --git a/python/tvm/script/_parser/ir/parser.py b/python/tvm/script/_parser/ir/parser.py new file mode 100644 index 0000000000..eacbe9641c --- /dev/null +++ b/python/tvm/script/_parser/ir/parser.py @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from ...ir_builder import ir as I +from .._core import Parser, dispatch, doc + + +@dispatch.register(token="ir", type_name="ClassDef") +def _visit_class_def(self: Parser, node: doc.ClassDef) -> None: + with self.var_table.with_frame(): + with I.ir_module(): + for stmt in node.body: + if isinstance(stmt, doc.FunctionDef): + global_var = I.decl_function(stmt.name) + self.var_table.add(stmt.name, global_var) + with self.with_dispatch_token("ir"): + self.visit_body(node.body) + + +@dispatch.register(token="ir", type_name="Assign") +def _visit_assign(_self: Parser, _node: doc.Assign) -> None: + pass + + +@dispatch.register(token="ir", type_name="Expr") +def _visit_expr(_self: Parser, _node: doc.Expr) -> None: + pass diff --git a/python/tvm/script/_parser/relax/__init__.py b/python/tvm/script/_parser/relax/__init__.py new file mode 100644 index 0000000000..ed85bd8af6 --- /dev/null +++ b/python/tvm/script/_parser/relax/__init__.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from ...ir_builder.relax import * # pylint: disable=redefined-builtin +from ...ir_builder.relax import ir as _relax +from . import parser as _parser +from .entry import Callable, Tensor, function, match_shape + +__all__ = _relax.__all__ + ["Callable", "Tensor", "function", "match_shape"] diff --git a/python/tvm/script/_parser/relax/entry.py b/python/tvm/script/_parser/relax/entry.py new file mode 100644 index 0000000000..453afbaf17 --- /dev/null +++ b/python/tvm/script/_parser/relax/entry.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +import inspect +from typing import Callable as _Callable +from typing import List, Optional +from typing import TypeVar as _TypeVar +from typing import Union + +from tvm.ir import FuncType, TypeConstraint, TypeVar +from tvm.relax import Expr, Function, Type, Var +from tvm.tir import PrimExpr + +from ...ir_builder.relax import TensorType, tensor +from .._core import parse, utils +from ..ir import is_defined_in_class + +FType = _TypeVar("FType", bound=_Callable) + + +def function(f: FType) -> Union[Function, FType]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if is_defined_in_class(inspect.stack()): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(function, "dispatch_token", "relax") + + +class TensorProxy: + def __call__( + self, + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: str = None, + ndim: int = -1, + ) -> TensorType: + return tensor(shape, dtype, ndim) + + def __getitem__(self, keys) -> Var: + return self(*keys) # pylint: disable=no-member # type: ignore + + +Tensor = TensorProxy() # pylint: disable=invalid-name + + +class CallableProxy: + """Function type. + + A function type consists of a list of type parameters to enable + the definition of generic functions, + a set of type constraints which we omit for the time being, + a sequence of argument types, and a return type. + + We can informally write them as: + `forall (type_params), (arg_types) -> ret_type where type_constraints` + + Parameters + ---------- + arg_types : List[Type] + The argument types + + ret_type : Type + The return type. + + type_params : Optional[List[TypeVar]] + The type parameters + + type_constraints : Optional[List[TypeConstraint]] + The type constraints. + """ + + def __call__( + self, + arg_types: List[Type], + ret_type: Type, + type_params: Optional[List[TypeVar]] = None, + type_constraints: Optional[List[TypeConstraint]] = None, + ) -> FuncType: + def _convert_type(ty: Union[Type, TensorType]) -> Type: + if isinstance(ty, TensorType): + return ty.type + elif isinstance(ty, Type): + return ty + else: + raise TypeError(f"Expect a Type or TensorType, but got: {ty}") + + arg_types = [_convert_type(ty) for ty in arg_types] + ret_type = _convert_type(ret_type) + return FuncType(arg_types, ret_type, type_params, type_constraints) + + def __getitem__(self, keys) -> Var: + return self(*keys) # pylint: disable=no-member # type: ignore + + +Callable = CallableProxy() + + +class MatchShapePair: + value: Expr + pattern: List[PrimExpr] + + def __init__(self, value: Expr, pattern: List[PrimExpr]) -> None: + self.value = value + self.pattern = pattern + + +def match_shape(value: Expr, pattern: List[PrimExpr]): + return MatchShapePair(value, pattern) diff --git a/python/tvm/script/_parser/relax/parser.py b/python/tvm/script/_parser/relax/parser.py new file mode 100644 index 0000000000..f8101a6e6c --- /dev/null +++ b/python/tvm/script/_parser/relax/parser.py @@ -0,0 +1,357 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import contextlib +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, Union + +from tvm import relax, tir +from tvm.ir import Type +from tvm.script.ir_builder.relax.frame import BlockFrame + +from ...ir_builder import relax as R +from ...ir_builder.base import IRBuilder +from .._core import Parser, dispatch, doc +from .entry import MatchShapePair, Tensor, TensorType + + +class VarDefLoc: + def __init__(self, name: str, line: int, col: int): + self.name = name + self.line = line + self.col = col + + def __str__(self): + return f"{self.name}@{self.line}:{self.col}" + + def __repr__(self): + return f"{self.name}@{self.line}:{self.col}" + + +def collect_var_definitions(stmts: List[doc.stmt]) -> Dict[str, List[VarDefLoc]]: + class Collector(doc.NodeVisitor): + results: Dict[str, List[VarDefLoc]] + + def __init__(self): + self.results = defaultdict(list) + + def visit_Name(self, node: doc.Name): # pylint: disable=invalid-name + assert isinstance(node.ctx, doc.Store) + assert node.id + assert node.lineno is not None + assert node.col_offset is not None + self.results[node.id].append( + VarDefLoc( + node.id, + node.lineno, + node.col_offset, + ) + ) + + collector = Collector() + for stmt in stmts: + if isinstance(stmt, doc.Assign): + assert len(stmt.targets) == 1 + collector.visit(stmt.targets[0]) + elif isinstance(stmt, doc.AugAssign): + collector.visit(stmt.target) + + return collector.results + + +def bind_value_with_dataflow_var_names( + dataflow_var_names: List[str], var_def_table: Optional[Dict[str, List[VarDefLoc]]] +): + def bind_assign_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + var_table = self.var_table.get() + + if isinstance(value, tir.Var): + if value.name and var_name != value.name: + self.report_error( + node, + "Cannot define TIR variables with different names. The LHS of binding should " + "has the same name provided in RHS.", + ) + if var_name in var_table: + prev_value = var_table[var_name] + if not isinstance(prev_value, tir.Var): + self.report_error( + node, + "Cannot redefine a non-TIR-variable object to a TIR variable. Please " + "define the TIR variable with another name.", + ) + if prev_value.dtype != value.dtype: + self.report_error( + node, + "Expected the same dtype for TIR vars " + f"but got {value.dtype} vs {prev_value.dtype}", + ) + return prev_value + IRBuilder.name(var_name, value) + return value + + is_dataflow_var = False + if var_def_table is not None and ( + var_name not in dataflow_var_names or node.lineno != var_def_table[var_name][-1].line + ): + is_dataflow_var = True + + if isinstance(value, relax.Expr): + var = R.emit(value, is_dataflow_var) + # It's an internal check, so directly use assert here. + assert var is not None + IRBuilder.name(var_name, var) + return var + elif isinstance(value, MatchShapePair): + var = R.emit_match_shape( + value.value, value.pattern, emit_var=True, is_dataflow_var=is_dataflow_var + ) + # It's an internal check, so directly use assert here. + assert var is not None + IRBuilder.name(var_name, var) + return var + else: + raise TypeError(f"Unsupported type {type(value)} in assignment") + + return bind_assign_value + + +def eval_type_annotation(self: Parser, node: Union[doc.Expression, doc.expr]) -> Any: + type_annotation = self.eval_expr(node) + if callable(type_annotation): + type_annotation = Tensor() + if isinstance(type_annotation, TensorType): + shape = type_annotation.shape + if shape is None: + return type_annotation.type, None + shape = list(shape.values) + var_table = self.var_table.get() + for i, expr in enumerate(shape): + # Define the symbolic shape var + if isinstance(expr, tir.Var): + name = expr.name + if name in var_table: + shape[i] = var_table[name] + else: + self.var_table.add(name, shape[i]) + return type_annotation.type, relax.ShapeExpr(shape) + else: + if not isinstance(type_annotation, Type): + self.report_error(node, f"Unsupported type annotation {type(type_annotation)}") + return type_annotation, None + + +@dispatch.register(token="relax", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + with R.function(): + R.func_name(node.name) + if node.returns is not None: + ann_type, _ = eval_type_annotation(self, node.returns) + R.func_ret_type(ann_type) + with self.with_dispatch_token("relax"): + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="relax", type_name="pre_token_switch") +def pre_token_switch(self: Parser, node: doc.Expr) -> None: # pylint: disable=unused-argument + ir_builder = IRBuilder() + ir_builder.__enter__() + + +@dispatch.register(token="relax", type_name="post_token_switch") +def post_token_switch(self: Parser, node: doc.Expr) -> None: + ir_builder = IRBuilder.current() + result = ir_builder.get() + ir_builder.__exit__(None, None, None) + var = R.emit(result, is_dataflow_var=False) + IRBuilder.name(node.name, var) + self.var_table.add(node.name, var, allow_shadowing=False) + + +@dispatch.register(token="relax", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + value = self.eval_expr(node.value) + if isinstance(value, MatchShapePair): + R.emit_match_shape(value.value, value.pattern, emit_var=False, is_dataflow_var=False) + elif isinstance(value, tuple): + # Currently `res` must be the return value of `R.output`. In order to make these variables + # accessible to the bindings of following binding blocks, we should pop these variables into + # the variable table of one level higher. + for var_name in self.var_table.frames[-1].vars: + if self.var_table.name2value[var_name][-1] in value: + var = self.var_table.name2value[var_name][-1] + # Pop up the variable to the variable table one level higher. + if var_name in self.var_table.frames[-2].vars: + self.var_table.name2value[var_name][-2] = var + else: + self.var_table.frames[-2].add(var_name) + self.var_table.name2value[var_name].append(var) + elif value is not None: + self.report_error(node, f"Unsupported Expr stmt type {value}.") + + +@dispatch.register(token="relax", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param_type, param_shape = self.visit_tvm_annotation(arg.annotation) + param = R.arg(arg.arg, param_type, param_shape) + + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="relax", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr): + return eval_type_annotation(self, node) + + +@dispatch.register(token="relax", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + # Currently only `with R.dataflow()` is supported + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + if len(node.items) != 1: + self.report_error(node, "Only one dataflow block is allowed") + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, BlockFrame): + self.report_error( + item.context_expr, "Invalid context expression in the with-statement." + ) + stack.enter_context(frame) + if item.optional_vars is not None: + self.report_error( + item.context_expr, + "Relax syntax doesn't allow binding expressions in `with` to variables", + ) + + assert isinstance(node.body, list) + var_def_table = collect_var_definitions(node.body) + + if ( + not isinstance(node.body[-1], doc.Expr) + or not isinstance(node.body[-1].value, doc.Call) + or node.body[-1].value.func.attr != "output" + ): + self.report_error( + node.body[-1], + "Relax dataflow blocks must have output. However, the last statement inside a " + "dataflow block is not `R.output`. Please use `R.output` to specify the output of " + "the dataflow block.", + ) + + dataflow_var_names = [] + for arg in node.body[-1].value.args: + if not isinstance(arg, doc.Name): + self.report_error( + arg, + "The output of Relax dataflow blocks must be all variables. However, one of " + "the dataflow block output is not a variable. Please make sure all output are " + "variables.", + ) + dataflow_var_names.append(arg.id) + + for i in range(len(node.body) - 1): + if not isinstance(node.body[i], doc.Assign): + self.report_error( + node.body[i], + "One non-assign statement appears unexpectedly inside a dataflow block. Only " + "the last statement inside a dataflow block is an Expr. Please make sure this " + "statement appears at a correct position.", + ) + if len(node.body[i].targets) != 1: + self.report_error( + node.body[i], "Consequential assignments like 'a = b = c' are not supported." + ) + lhs = node.body[i].targets[0] + rhs = self.eval_expr(node.body[i].value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_value_with_dataflow_var_names(dataflow_var_names, var_def_table), + allow_shadowing=True, + ) + + self.visit(node.body[-1]) + + +@dispatch.register(token="relax", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_value_with_dataflow_var_names(dataflow_var_names=[], var_def_table=None), + allow_shadowing=True, + ) + + +@dispatch.register(token="relax", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + ann_type, ann_shape = self.visit_tvm_annotation(node.annotation) + self.eval_assign( + target=lhs, + source=rhs, + bind_value=bind_value_with_dataflow_var_names(dataflow_var_names=[], var_def_table=None), + allow_shadowing=True, + ) + var = self.var_table.get().get(lhs.id) + assert isinstance(var, relax.Var) + R.ir.annotate_type_shape(var, ann_type, ann_shape) + + +@dispatch.register(token="relax", type_name="Return") +def visit_return(self: Parser, node: doc.Assign) -> None: + value = self.eval_expr(node.value) + + if isinstance(value, relax.Expr): + R.func_ret_value(value) + elif isinstance(value, Tuple): + if all([isinstance(f, tir.PrimExpr) for f in value]): + R.func_ret_value(relax.ShapeExpr(value)) + elif any([isinstance(f, tir.PrimExpr) for f in value]): + self.report_error( + node, "Return types, with mixed PrimExpr and Relax Expr, is not supported." + ) + else: + R.func_ret_value(relax.Tuple(value)) + else: + self.report_error(node, f"Unsupported return value type {type(value)}.") + + +@dispatch.register(token="relax", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + if node.orelse is None: + raise ValueError("Else statements are required for relax dialect.") + with R.If(self.eval_expr(node.test)) as if_frame: + with self.var_table.with_frame(): + with R.Then(): + self.visit_body(node.body) + with self.var_table.with_frame(): + with R.Else(): + self.visit_body(node.orelse) + self.var_table.add(if_frame.var_name, if_frame.var, allow_shadowing=True) diff --git a/python/tvm/script/_parser/tir/__init__.py b/python/tvm/script/_parser/tir/__init__.py new file mode 100644 index 0000000000..930764f73d --- /dev/null +++ b/python/tvm/script/_parser/tir/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from ...ir_builder.tir import * # pylint: disable=redefined-builtin +from ...ir_builder.tir import ir as _tir +from . import operation as _operation +from . import parser as _parser +from .entry import Buffer, Ptr, prim_func + +__all__ = _tir.__all__ + ["Buffer", "Ptr", "prim_func"] diff --git a/python/tvm/script/_parser/tir/entry.py b/python/tvm/script/_parser/tir/entry.py new file mode 100644 index 0000000000..07bd75f351 --- /dev/null +++ b/python/tvm/script/_parser/tir/entry.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import inspect +from typing import Callable, Union + +from tvm.tir import Buffer, PrimFunc + +from ...ir_builder.tir import buffer_decl, ptr +from .._core import parse, utils +from ..ir import is_defined_in_class + + +def prim_func(f: Callable) -> Union[PrimFunc, Callable]: + if not inspect.isfunction(f): + raise TypeError(f"Expect a function, but got: {f}") + if is_defined_in_class(inspect.stack()): + return f + return parse(f, utils.inspect_function_capture(f)) + + +setattr(prim_func, "dispatch_token", "tir") + + +class BufferProxy: + def __call__( + self, + shape, + dtype="float32", + data=None, + strides=None, + elem_offset=None, + scope="global", + align=0, + offset_factor=0, + buffer_type="", + axis_separators=None, + ) -> Buffer: + return buffer_decl( + shape, + dtype=dtype, + data=data, + strides=strides, + elem_offset=elem_offset, + scope=scope, + align=align, + offset_factor=offset_factor, + buffer_type=buffer_type, + axis_separators=axis_separators, + ) + + def __getitem__(self, keys) -> Buffer: + if not isinstance(keys, tuple): + return self(keys) + if len(keys) >= 2 and not isinstance(keys[1], str): + return self(keys) + return self(*keys) # pylint: disable=no-member # type: ignore + + +class PtrProxy: + def __call__(self, dtype, storage_scope="global"): + if callable(dtype): + dtype = dtype().dtype + return ptr(dtype, storage_scope) # pylint: disable=no-member # type: ignore + + def __getitem__(self, keys): + if not isinstance(keys, tuple): + return self(keys) + return self(*keys) + + +Buffer = BufferProxy() # pylint: disable=invalid-name +Ptr = PtrProxy() # pylint: disable=invalid-name diff --git a/python/tvm/script/_parser/tir/operation.py b/python/tvm/script/_parser/tir/operation.py new file mode 100644 index 0000000000..87fb9406ae --- /dev/null +++ b/python/tvm/script/_parser/tir/operation.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import Type + +from tvm import tir +from tvm.tir import IntImm + +from .._core import OpMethod, doc, register_op + + +def _register_expr_op(ty: Type): # pylint: disable=invalid-name + ty._dispatch_type = ty # pylint: disable=protected-access + + def _and(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + return tir.And(a, b) + + def _or(a, b): + if isinstance(a, bool): + a = IntImm("bool", a) + if isinstance(b, bool): + b = IntImm("bool", b) + return tir.Or(a, b) + + def r(op: Type, i: int, m: OpMethod): # pylint: disable=invalid-name + register_op(ty, op, i)(m) + + for i in [0, 1]: + # Case 1. binop + r(doc.Add, i, lambda a, b: a + b) + r(doc.Sub, i, lambda a, b: a - b) + r(doc.Mult, i, lambda a, b: a * b) + r(doc.Div, i, lambda a, b: a / b) + r(doc.FloorDiv, i, lambda a, b: a // b) + r(doc.Mod, i, lambda a, b: a % b) + r(doc.LShift, i, lambda a, b: a << b) + r(doc.RShift, i, lambda a, b: a >> b) + r(doc.BitOr, i, lambda a, b: a | b) + r(doc.BitXor, i, lambda a, b: a ^ b) + r(doc.BitAnd, i, lambda a, b: a & b) + # doc.MatMult <-- not implemented + # doc.Pow <-- not implemented + # Case 2. cmpop + r(doc.Eq, i, tir.EQ) + r(doc.NotEq, i, tir.NE) + r(doc.Lt, i, tir.LT) + r(doc.LtE, i, tir.LE) + r(doc.Gt, i, tir.GT) + r(doc.GtE, i, tir.GE) + # doc.Is <-- not implemented + # doc.IsNot <-- not implemented + # doc.In <-- not implemented + # doc.NotIn <-- not implemented + # Case 3. boolop + r(doc.And, i, _and) + r(doc.Or, i, _or) + for i in [0]: + # Case 4. unaryop + r(doc.Invert, i, lambda a: ~a) + r(doc.Not, i, tir.Not) + r(doc.UAdd, i, lambda a: +a) + r(doc.USub, i, lambda a: -a) + + +_register_expr_op(tir.PrimExpr) +_register_expr_op(tir.IterVar) diff --git a/python/tvm/script/_parser/tir/parser.py b/python/tvm/script/_parser/tir/parser.py new file mode 100644 index 0000000000..032555187f --- /dev/null +++ b/python/tvm/script/_parser/tir/parser.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import contextlib +from functools import partial +from typing import Any + +from tvm.ir import PrimType +from tvm.tir import Buffer, IterVar, PrimExpr, Var + +from ...ir_builder import tir as T +from ...ir_builder.base import IRBuilder +from ...ir_builder.base import IRBuilderFrame as Frame +from .._core import Parser, dispatch, doc + + +def bind_with_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_with_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, (Buffer, Var)): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in with statement") + raise NotImplementedError + + +def bind_for_value(self: Parser, node: doc.expr, var_name: str, value: Any) -> Any: + if isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_with_value(self, node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Var): + IRBuilder.name(var_name, value) + return value + else: + self.report_error(node, f"Do not know how to bind type: {type(value)} in for statement") + raise NotImplementedError + + +def bind_assign_value(self: Parser, _node: doc.expr, var_name: str, value: Any) -> Any: + if isinstance(value, T.inline): + return value.value + elif isinstance(value, (list, tuple)): + for i, v in enumerate(value): + bind_with_value(self, _node, f"{var_name}_{i}", v) + return value + elif isinstance(value, Frame): + value.add_callback(partial(value.__exit__, None, None, None)) + res = value.__enter__() + IRBuilder.name(var_name, res) + return res + elif isinstance(value, (Buffer, IterVar)) or ( + isinstance(value, Var) and not self.var_table.exist(value) + ): + IRBuilder.name(var_name, value) + return value + elif isinstance(value, PrimExpr): + var = T.var(value.dtype) + IRBuilder.name(var_name, var) + frame = T.let(var, value) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + return var + return value + + +@dispatch.register(token="tir", type_name="For") +def visit_for(self: Parser, node: doc.For) -> None: + for_frame = self.eval_expr(node.iter) + if not isinstance(for_frame, T.frame.ForFrame): + self.report_error( + node.iter, + "Expect the for loop to be one of the following: " + "range, T.serial, T.grid, T.parallel, T.vectorized, T.unroll, T.thread_binding", + ) + with self.var_table.with_frame(): + with for_frame as iters: + self.eval_assign(target=node.target, source=iters, bind_value=bind_for_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="While") +def visit_while(self: Parser, node: doc.While) -> None: + with self.var_table.with_frame(): + cond = self.eval_expr(node.test) + with T.While(cond): + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="Assign") +def visit_assign(self: Parser, node: doc.Assign) -> None: + if len(node.targets) != 1: + self.report_error(node, "Consequential assignments like 'a = b = c' are not supported.") + lhs = node.targets[0] + rhs = self.eval_expr(node.value) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AugAssign") +def visit_aug_assign(self: Parser, node: doc.AugAssign) -> None: + lhs_pos = ( + node.target.lineno, + node.target.col_offset, + node.target.end_lineno, + node.target.end_col_offset, + ) + rhs_pos = ( + node.value.lineno, + node.value.col_offset, + node.value.end_lineno, + node.value.end_col_offset, + ) + node.target.ctx = doc.Load(*lhs_pos) + with self.var_table.with_frame(): + lhs_name = "__tvm_tmp_value_aug_assign_lhs" + rhs_name = "__tvm_tmp_value_aug_assign_rhs" + lhs_expr = self.eval_expr(node.target) + rhs_expr = self.eval_expr(node.value) + self.var_table.add(lhs_name, lhs_expr) + self.var_table.add(rhs_name, rhs_expr) + op = doc.BinOp( + doc.Name(lhs_name, doc.Load(*lhs_pos), *lhs_pos), + node.op, + doc.Name(rhs_name, doc.Load(*rhs_pos), *rhs_pos), + *lhs_pos, + ) + rhs = self.eval_expr(op) + lhs = node.target + lhs.ctx = doc.Store(*lhs_pos) + if isinstance(lhs, doc.Subscript): + if isinstance(lhs.slice, doc.Tuple): + indices = [] + for index in lhs.slice.elts: + indices.append(self.eval_expr(index)) + else: + indices = [self.eval_expr(lhs.slice)] + T.buffer_store(self.eval_expr(lhs.value), rhs, indices) + else: + self.eval_assign(target=lhs, source=rhs, bind_value=bind_assign_value) + + +@dispatch.register(token="tir", type_name="AnnAssign") +def visit_ann_assign(self: Parser, node: doc.AnnAssign) -> None: + lhs = node.target + rhs = self.eval_expr(node.value) + ann_var = self.visit_tvm_annotation(node.annotation) + if not isinstance(ann_var, Var): + self.report_error(node.annotation, "Annotation should be Var") + self.eval_assign(target=lhs, source=ann_var, bind_value=bind_assign_value) + frame = T.let(ann_var, rhs) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="With") +def visit_with(self: Parser, node: doc.With) -> None: + with contextlib.ExitStack() as stack: + stack.enter_context(self.var_table.with_frame()) + for item in node.items: + frame = self.eval_expr(item.context_expr) + if not isinstance(frame, Frame): + self.report_error( + item.context_expr, "Invalid context expression in the with-statement." + ) + rhs = stack.enter_context(frame) + if item.optional_vars is not None: + self.eval_assign(target=item.optional_vars, source=rhs, bind_value=bind_with_value) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="FunctionDef") +def visit_function_def(self: Parser, node: doc.FunctionDef) -> None: + with self.var_table.with_frame(): + self.var_table.add("range", T.serial) + with T.prim_func(): + T.func_name(node.name) + if node.returns is not None: + ret_type = self.eval_expr(node.returns) + if callable(ret_type): + ret_type = PrimType(ret_type().dtype) + T.func_ret(ret_type) + with self.with_dispatch_token("tir"): + self.visit(node.args) + self.visit_body(node.body) + + +@dispatch.register(token="tir", type_name="arguments") +def visit_arguments(self: Parser, node: doc.arguments) -> None: + # TODO: handle different types of arguments: + # - vararg: arg | None + # - kwonlyargs: list[arg] + # - kw_defaults: list[expr | None] + # - kwarg: arg | None + # - defaults: list[expr] + # - posonlyargs: list[arg] + arg: doc.arg + for arg in node.args: + if arg.annotation is None: + self.report_error(arg, "Type annotation is required for function parameters.") + param = T.arg(arg.arg, self.visit_tvm_annotation(arg.annotation)) + self.var_table.add(arg.arg, param) + + +@dispatch.register(token="tir", type_name="tvm_annotation") +def visit_tvm_annotation(self: Parser, node: doc.expr): + annotation = self.eval_expr(node) + if callable(annotation): + annotation = annotation() + return annotation + + +@dispatch.register(token="tir", type_name="Expr") +def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: + res = self.eval_expr(node.value) + if isinstance(res, Frame): + res.add_callback(partial(res.__exit__, None, None, None)) + res.__enter__() + + +@dispatch.register(token="tir", type_name="If") +def visit_if(self: Parser, node: doc.If) -> None: + with self.var_table.with_frame(): + with T.If(self.eval_expr(node.test)): + with T.Then(): + self.visit_body(node.body) + if node.orelse: + with T.Else(): + self.visit_body(node.orelse) + + +@dispatch.register(token="tir", type_name="Assert") +def visit_assert(self: Parser, node: doc.Assert) -> None: + cond = self.eval_expr(node.test) + msg = self.eval_expr(node.msg) + frame = T.Assert(cond, msg) + frame.add_callback(partial(frame.__exit__, None, None, None)) + frame.__enter__() + + +@dispatch.register(token="tir", type_name="Return") +def visit_return(self: Parser, node: doc.Return) -> None: + self.report_error(node, "Return is not allowed.") diff --git a/python/tvm/script/ir_builder/base.py b/python/tvm/script/ir_builder/base.py index 7aa33ee49c..69f13b2145 100644 --- a/python/tvm/script/ir_builder/base.py +++ b/python/tvm/script/ir_builder/base.py @@ -64,8 +64,10 @@ def __enter__(self) -> "IRBuilderFrame": _ffi_api.IRBuilderFrameEnter(self) # type: ignore[attr-defined] # pylint: disable=no-member return self - def __exit__(self, ptype, value, trace) -> None: # pylint: disable=unused-argument - _ffi_api.IRBuilderFrameExit(self) # type: ignore[attr-defined] # pylint: disable=no-member + def __exit__(self, exc_type, exc_value, trace) -> None: # pylint: disable=unused-argument + if exc_type is None and exc_value is None: + # Do not execute `FrameExit` if the with scope exits because of exceptions + _ffi_api.IRBuilderFrameExit(self) # pylint: disable=no-member # type: ignore def add_callback(self, callback: Callable[[], None]) -> None: """Add a callback method invoked when exiting the with-scope. diff --git a/python/tvm/script/ir_builder/ir/__init__.py b/python/tvm/script/ir_builder/ir/__init__.py index ebb9728737..946be263a7 100644 --- a/python/tvm/script/ir_builder/ir/__init__.py +++ b/python/tvm/script/ir_builder/ir/__init__.py @@ -16,4 +16,4 @@ # under the License. """Package tvm.script.ir_builder.ir""" from .frame import IRModuleFrame -from .ir import ir_module +from .ir import decl_function, def_function, ir_module diff --git a/python/tvm/script/ir_builder/ir/ir.py b/python/tvm/script/ir_builder/ir/ir.py index 213180463c..ac7d479e1a 100644 --- a/python/tvm/script/ir_builder/ir/ir.py +++ b/python/tvm/script/ir_builder/ir/ir.py @@ -16,9 +16,46 @@ # under the License. """Package tvm.script.ir_builder.ir.ir""" +from tvm.ir import BaseFunc, GlobalVar + from . import _ffi_api from .frame import IRModuleFrame def ir_module() -> IRModuleFrame: + """Start a ir_module frame. + Returns + ------- + frame: IRModuleFrame + The constructed frame. + """ return _ffi_api.IRModule() # type: ignore[attr-defined] # pylint: disable=no-member + + +def decl_function(func_name: str) -> GlobalVar: + """Declare a Function without given the specific function implementation. + Parameters + ---------- + func_name : str + The function unique name. + Note + ---- + It is usually used in cross-function call. And we can specify the function by `DefFunction` + Returns + ------- + gv : GlobalVar + The corresponding GlobalVar. + """ + return _ffi_api.DeclFunction(func_name) # pylint: disable=no-member # type: ignore + + +def def_function(func_name: str, func: BaseFunc) -> None: + """Define the function which is declared before. + Parameters + ---------- + func_name : str + The function unique name. + func: BaseFunc + The given function implementation + """ + return _ffi_api.DefFunction(func_name, func) # pylint: disable=no-member # type: ignore diff --git a/python/tvm/script/ir_builder/relax/__init__.py b/python/tvm/script/ir_builder/relax/__init__.py new file mode 100644 index 0000000000..f0905acf34 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-import +"""Package tvm.script.ir_builder.relax""" +from . import frame +from .ir import * # pylint: disable=wildcard-import,redefined-builtin diff --git a/python/tvm/script/ir_builder/relax/_ffi_api.py b/python/tvm/script/ir_builder/relax/_ffi_api.py new file mode 100644 index 0000000000..6e2098cf88 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.ir_builder.relax""" +import tvm._ffi + +tvm._ffi._init_api("script.ir_builder.relax", __name__) # pylint: disable=protected-access diff --git a/python/tvm/script/ir_builder/relax/frame.py b/python/tvm/script/ir_builder/relax/frame.py new file mode 100644 index 0000000000..97e181fbe4 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/frame.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""IR Builder Frame for Relax dialect""" +from tvm._ffi import register_object as _register_object + +from ..base import IRBuilderFrame + + +@_register_object("script.ir_builder.relax.RelaxFrame") +class RelaxFrame(IRBuilderFrame): + """The base ir_builder frame for the relax dialect.""" + + +@_register_object("script.ir_builder.relax.SeqExprFrame") +class SeqExprFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.FunctionFrame") +class FunctionFrame(SeqExprFrame): + """The ir_builder frame for the relax function.""" + + +@_register_object("script.ir_builder.relax.BlockFrame") +class BlockFrame(RelaxFrame): + """The ir_builder frame for relax binding blocks.""" + + +@_register_object("script.ir_builder.relax.IfFrame") +class IfFrame(RelaxFrame): + ... + + +@_register_object("script.ir_builder.relax.ThenFrame") +class ThenFrame(SeqExprFrame): + ... + + +@_register_object("script.ir_builder.relax.ElseFrame") +class ElseFrame(SeqExprFrame): + ... diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py new file mode 100644 index 0000000000..fcefc54bd1 --- /dev/null +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -0,0 +1,369 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, wrong-import-order +"""IRBuilder for Relax dialect""" + +from typing import Dict, List, Optional, Tuple, Union + +from tvm._ffi import register_object as _register_object +from tvm.ir import Attrs, Type +from tvm.relax import Call, Expr, ExternFunc, ShapeExpr, Var + +############################### Operators ############################### +from tvm.relax.op import ( + add, + builtin, + call_tir, + invoke_closure, + make_closure, + multiply, + print, + shape_of, + unique, +) +from tvm.relax.ty import ObjectType, ShapeType +from tvm.runtime import Object as tvm_Object +from tvm.tir import PrimExpr + +from ..tir import var as _tir_var +from . import _ffi_api, frame + +############################## Tensor Type ############################## + + +@_register_object("script.ir_builder.relax.TensorType") +class TensorType(tvm_Object): + """A temporary Tensor type for `R.Tensor` in ir_builder.""" + + +def tensor( + shape: Optional[List[Union[PrimExpr, str]]] = None, + dtype: Optional[str] = None, + ndim: int = -1, +): + """Helper function for `R.Tensor` in parser + Parameters + ---------- + shape: Optional[List[Union[PrimExpr, str]]] + The shape of the tensor. It's runtime dependent if `shape` is None. + dtype: Optional[str] + The element data type of the tensor. It's runtime dependent if `dtype` is None. + ndim: int + The number of dimensions of the tensor. It's runtime dependent if `ndim` is -1. + Returns + ------- + tensor_type: TensorType + The TensorType that is only used in ir_builder. + """ + + if shape is not None: + if not isinstance(shape, list): + shape = list(shape) + + for i, s in enumerate(shape): + if isinstance(s, str): + shape[i] = _tir_var("int64", s) + + return _ffi_api.Tensor(shape, dtype, ndim) # pylint: disable=no-member # type: ignore + + +############################## Other Types ############################## + +Object = ObjectType() # pylint: disable=invalid-name +Shape = ShapeType() # pylint: disable=invalid-name + +############################### Function ################################ + + +def function() -> frame.FunctionFrame: + """Start a function frame. + Returns + ------- + frame: FunctionFrame + The constructed function frame. + """ + return _ffi_api.Function() # pylint: disable=no-member # type: ignore + + +def arg(name: str, type: Union[Type, TensorType], shape: Optional[ShapeExpr] = None) -> Var: + """Add a parameter to the last function frame. + Parameters + ---------- + name: str + The name of the parameter. + type: Union[Type, TensorType] + The type of the parameter. It can be a typical TVM Type or a TensorType, + which contains both type and shape + shape: Optional[ShapeExpr] + The shape of the parameter. + Returns + ------- + var: Var + The created function parameter var. + """ + + if isinstance(type, TensorType): + if shape is not None: + raise ValueError("Cannot specify the shape if we use TensorType") + shape = type.shape + type = type.type + + return _ffi_api.Arg(name, type, shape) # pylint: disable=no-member # type: ignore + + +def func_name(name: str) -> None: + """Specify the name of the last function frame. + Parameters + ---------- + name: str + The function name. + """ + return _ffi_api.FuncName(name) # pylint: disable=no-member # type: ignore + + +def func_attr(attrs: Dict[str, tvm_Object]) -> None: + """Specify the attrs of the last function frame. + Parameters + ---------- + attrs: Dict[str, Object] + The function attrs. + """ + return _ffi_api.FuncAttrs(attrs) # pylint: disable=no-member # type: ignore + + +def func_ret_type(ret_type: Union[TensorType, Type]) -> None: + """Specify the return type of the last function frame. + Parameters + ---------- + ret_type: Union[TensorType, Type] + The function return type. + """ + if isinstance(ret_type, TensorType): + ret_type = ret_type.type + return _ffi_api.FuncRetType(ret_type) # pylint: disable=no-member # type: ignore + + +def func_ret_value(value: Expr) -> None: + """Specify the return value of the last function frame. + Parameters + ---------- + value: Expr + The function return value. + """ + return _ffi_api.FuncRetValue(value) # pylint: disable=no-member # type: ignore + + +############################# BindingBlock ############################## + + +def dataflow() -> frame.BlockFrame: + """Start a dataflow binding block frame. + Returns + ------- + frame: frame.BlockFrame + The created ir_builder Block frame. + """ + return _ffi_api.Dataflow() # pylint: disable=no-member # type: ignore + + +def output(*vars: Tuple[Var]) -> Tuple[Var]: + """Expose the dataflow block output variables as global ones. + Parameters + ---------- + vars: Tuple[Var] + The output variables of a dataflow block. + Returns + ------- + vars: Tuple[Var] + The output variables of a dataflow block. Return the input variables to parser side for + followup process + """ + _ffi_api.DataflowBlockOutput(vars) # pylint: disable=no-member # type: ignore + return vars + + +################################## Ops ################################# + + +def call_packed( + func: str, + *args: List[Expr], + attrs: Optional[Attrs] = None, + type_args: Optional[Union[TensorType, List[TensorType]]] = None, +) -> Call: + """Create a relax Call, which calls a packed function. + Parameters + ---------- + func: str + The name of extern function. + args : List[Expr] + The arguments. + attrs: Optional[Attrs] + The call attributes + type_args: Optional[Union[TensorType, List[TensorType]]] + List of Types + Returns + ------- + call: Call + The created Relax Call + """ + op = ExternFunc(func) + if type_args is None: + raise ValueError(f"R.call_packed is required to have type_args") + if isinstance(type_args, (TensorType, Type)): + type_args = [type_args] + elif isinstance(type_args, tuple): + type_args = list(type_args) + for i, argument in enumerate(type_args): + if isinstance(argument, TensorType): + type_args[i] = argument.type + elif isinstance(argument, Type): + type_args[i] = argument + else: + raise TypeError( + "call_packed `type_args` is expected to be list of TensorType/Type, " + f"but got {type(arg)}" + ) + + return Call(op, args, attrs=attrs, type_args=type_args) + + +############################### Bindings ############################### + + +def emit(value: Expr, is_dataflow_var: bool) -> Var: + """Emit a binding to the last binding block frame. + Parameters + ---------- + value: Expr + The right side value of the bindings to be emitted. + is_dataflow_var: bool + A boolean indicating if the emitted binding variable is a dataflow variable. + Returns + ------- + var: Var + The left side var of the emitted binding. + """ + return _ffi_api.Emit(value, is_dataflow_var) # pylint: disable=no-member # type: ignore + + +def emit_match_shape( + value: Expr, pattern: List[PrimExpr], emit_var: bool, is_dataflow_var: bool +) -> Optional[Var]: + """Emit a match_shape binding to the last binding block frame. + Parameters + ---------- + value: Expr + The value of the MatchShape to be emitted. + pattern: List[PrimExpr] + The pattern of the MatchShape to be emitted. + emit_var: bool + A boolean indicating if the MatchShape contains the emitted variable. + is_dataflow_var: bool + A boolean indicating if the emitted variable is a dataflow variable when `emit_var` is True. + When `emit_var` is False, the value of this flag will be ignored. + Returns + ------- + var: Optional[Var] + The emitted var if `emit_var` is True. Otherwise, return `None`. + """ + return _ffi_api.EmitMatchShape(value, pattern, emit_var, is_dataflow_var) # type: ignore + + +############################# Type Deduce ############################## + + +def annotate_type_shape(var: Var, type: Type, shape: ShapeExpr) -> None: + """Annotate and check the type of relax var. + Parameters + ---------- + var: Var + The input var to be annotated. + type: Type + The given type + shape: ShapeExpr + The given shape + """ + _ffi_api.AnnotateTypeShape(var, type, shape) + + +def If(condition: Expr) -> frame.IfFrame: # pylint: disable=invalid-name + """Create an if frame. + Parameters + ---------- + condition : Expr + The condition of if statement, executes the true branch if the condition is true, + otherwise jump into the false branch. + Returns + ------- + res : frame.IfFrame + The result IfFrame. + """ + return _ffi_api.If(condition) # pylint: disable=no-member # type: ignore + + +def Then() -> frame.ThenFrame: # pylint: disable=invalid-name + """Create a then frame. + Returns + ------- + res : frame.ThenFrame + The result ThenFrame. + """ + return _ffi_api.Then() # pylint: disable=no-member # type: ignore + + +def Else() -> frame.ElseFrame: # pylint: disable=invalid-name + """Create an else frame. + Returns + ------- + res : frame.ElseFrame + The result ElseFrame. + """ + return _ffi_api.Else() # pylint: disable=no-member # type: ignore + + +############################### Importer ############################### + +__all__ = [ + "Else", + "If", + "Object", + "Shape", + "TensorType", + "Then", + "add", + "arg", + "builtin", + "call_packed", + "call_tir", + "dataflow", + "emit", + "emit_match_shape", + "func_attr", + "func_name", + "func_ret_type", + "func_ret_value", + "function", + "invoke_closure", + "make_closure", + "multiply", + "output", + "print", + "unique", + "shape_of", + "tensor", +] diff --git a/python/tvm/script/parser/core/diagnostics.py b/python/tvm/script/parser/core/diagnostics.py index ad7ae50347..2767a97f60 100644 --- a/python/tvm/script/parser/core/diagnostics.py +++ b/python/tvm/script/parser/core/diagnostics.py @@ -220,7 +220,7 @@ def _emit(self, node: doc.AST, message: str, level: diagnostics.DiagnosticLevel) level : diagnostics.DiagnosticLevel The diagnostic level. """ - lineno = node.lineno or self.source.start_line + lineno = node.lineno or 1 col_offset = node.col_offset or self.source.start_column end_lineno = node.end_lineno or lineno end_col_offset = node.end_col_offset or col_offset diff --git a/python/tvm/script/relax/__init__.py b/python/tvm/script/parser_v1/relax/__init__.py similarity index 100% rename from python/tvm/script/relax/__init__.py rename to python/tvm/script/parser_v1/relax/__init__.py diff --git a/python/tvm/script/relax/function.py b/python/tvm/script/parser_v1/relax/function.py similarity index 100% rename from python/tvm/script/relax/function.py rename to python/tvm/script/parser_v1/relax/function.py diff --git a/python/tvm/script/relax/parser.py b/python/tvm/script/parser_v1/relax/parser.py similarity index 99% rename from python/tvm/script/relax/parser.py rename to python/tvm/script/parser_v1/relax/parser.py index 00b2f76dd2..976bdd55a8 100644 --- a/python/tvm/script/relax/parser.py +++ b/python/tvm/script/parser_v1/relax/parser.py @@ -23,22 +23,22 @@ import inspect import json from enum import Enum -from typing import Union, Dict, List, Tuple, Optional, Callable, Any -import synr -from synr import ast, Transformer +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import synr import tvm -from tvm import relay, relax, tir -from tvm.relax.utils import metadata_partitioner import tvm.script +from synr import Transformer, ast +from tvm import relax, relay, tir from tvm.ir import diagnostics from tvm.ir.module import IRModule -from tvm.script.tir.node import BufferSlice -import tvm.script.tir as tir_namespace -import tvm.script.relax as relax_namespace +from tvm.relax.utils import metadata_partitioner +from .. import relax as relax_namespace +from .. import tir as tir_namespace from ..parser import TVMScriptParser as _TIRScriptParser -from ..utils import tvm_span_from_synr, call_with_error_reporting +from ..tir.node import BufferSlice +from ..utils import call_with_error_reporting, tvm_span_from_synr def _is_registered(op_name: str, op_set=None) -> bool: diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index 8d81e374b2..5ceca53a11 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -54,6 +54,10 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { // TODO(@altanh): CopyOnWrite Expr VisitExpr(const Expr& expr) { + // TODO(relax-team): generalize prim_func support + if (expr->IsInstance()) { + return expr; + } Optional post = expr_memo_.Get(expr); if (post) { ICHECK(post.as()) << "memoized expressions should map to variables"; @@ -229,6 +233,7 @@ class BlockBuilderNode::ExprNormalizer : public ExprFunctor { new_false.same_as(op->false_branch)) { return GetRef(op); } + // TODO(relax-team): fix type/shape deduction for if node. return If(new_cond, new_true, new_false); } diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 39b8a0b9a5..163f9ddbf5 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -110,9 +110,9 @@ RELAY_REGISTER_OP("relax.print") .set_attr("FInferType", ReturnVoidType) .set_attr("FCallPacked", "relax.run.print"); -Expr MakePrint(Array vals, std::string format_str) { +Expr MakePrint(Array vals, std::string format) { auto attrs = make_object(); - attrs->format = format_str; + attrs->format = format; static const Op& op = Op::Get("relax.print"); return Call(op, vals, Attrs(attrs)); } @@ -229,9 +229,12 @@ RELAY_REGISTER_OP("relax.builtin.alloc_tensor") .set_attr("FInferShape", InferShapeAllocTensor) .set_attr("FInferType", InferTypeAllocTensor); -Expr MakeAllocTensor(Expr shape) { +Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + attrs->runtime_device_index = std::move(runtime_device_index); static const Op& op = Op::Get("relax.builtin.alloc_tensor"); - return Call(op, {shape}, {}, {}); + return Call(op, {shape}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); @@ -244,9 +247,12 @@ RELAY_REGISTER_OP("relax.vm.builtin.alloc_storage") .add_argument("size", "Expr", "The size of the storage to allocate.") .set_attr("FInferType", ReturnObjectType); -Expr MakeVMAllocStorage(Expr size) { +Expr MakeVMAllocStorage(Expr size, DataType dtype, int64_t runtime_device_index) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + attrs->runtime_device_index = std::move(runtime_device_index); static const Op& op = Op::Get("relax.vm.builtin.alloc_storage"); - return Call(op, {size}, {}, {}); + return Call(op, {size}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_storage").set_body_typed(MakeVMAllocStorage); @@ -274,9 +280,12 @@ RELAY_REGISTER_OP("relax.vm.builtin.alloc_tensor") .set_attr("FInferShape", InferShapeVMAllocTensor) .set_attr("FInferType", InferTypeVMAllocTensor); -Expr MakeVMAllocTensor(Expr storage, Expr shape) { +Expr MakeVMAllocTensor(Expr storage, Expr shape, DataType dtype, int64_t runtime_device_index) { + auto attrs = make_object(); + attrs->dtype = std::move(dtype); + attrs->runtime_device_index = std::move(runtime_device_index); static const Op& op = Op::Get("relax.vm.builtin.alloc_tensor"); - return Call(op, {storage, shape}, {}, {}); + return Call(op, {storage, shape}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.vm.builtin.alloc_tensor").set_body_typed(MakeVMAllocTensor); @@ -290,9 +299,11 @@ RELAY_REGISTER_OP("relax.vm.builtin.store_shape") .add_argument("heap", "Expr", "The heap to store the shape.") .set_attr("FInferType", ReturnVoidType); -Expr MakeStoreShape(Expr shape, Expr heap) { +Expr MakeStoreShape(Expr shape, Expr heap, Array indices) { + auto attrs = make_object(); + attrs->indices = std::move(indices); static const Op& op = Op::Get("relax.vm.builtin.store_shape"); - return Call(op, {shape, heap}, {}, {}); + return Call(op, {shape, heap}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.vm.builtin.store_shape").set_body_typed(MakeStoreShape); @@ -305,9 +316,11 @@ RELAY_REGISTER_OP("relax.vm.builtin.load_shape") .add_argument("heap", "Expr", "The heap to load the shape from.") .set_attr("FInferType", ReturnShapeType); -Expr MakeLoadShape(Expr heap) { +Expr MakeLoadShape(Expr heap, Array indices) { + auto attrs = make_object(); + attrs->indices = std::move(indices); static const Op& op = Op::Get("relax.vm.builtin.load_shape"); - return Call(op, {heap}, {}, {}); + return Call(op, {heap}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relax.op.vm.builtin.load_shape").set_body_typed(MakeLoadShape); diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc index 2032f285eb..228de3ae8c 100644 --- a/src/relax/op/tensor/unary.cc +++ b/src/relax/op/tensor/unary.cc @@ -46,7 +46,7 @@ Expr MakeUnique(Expr data, bool sorted, bool return_inverse, bool return_counts, attrs->return_inverse = return_inverse; attrs->return_counts = return_counts; attrs->dim = dim; - static const Op& op = Op::Get("unique"); + static const Op& op = Op::Get("relax.unique"); return Call(op, {data}, Attrs(attrs)); } diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc index 290008f62e..6a813d2151 100644 --- a/src/relax/transform/fuse_ops.cc +++ b/src/relax/transform/fuse_ops.cc @@ -29,9 +29,11 @@ #include #include +#include #include "../../relay/analysis/graph_partitioner.h" #include "../../support/arena.h" + namespace tvm { namespace relax { diff --git a/src/relax/transform/meta_schedule.cc b/src/relax/transform/meta_schedule.cc index c0f9b98bfa..b92e8a509b 100644 --- a/src/relax/transform/meta_schedule.cc +++ b/src/relax/transform/meta_schedule.cc @@ -21,131 +21,46 @@ * \file tvm/relax/transform/meta_schedule.cc * \brief Pass for meta_schedule tuning */ +#include #include #include #include +#include "../../printer/text_printer.h" + namespace tvm { namespace relax { -class MetaScheduleTuner { - public: - explicit MetaScheduleTuner(Target target, Array config, String work_dir) - : target_(target), config_(config), work_dir_(work_dir) { - candgen_func_ = runtime::Registry::Get("relax.tuning_api.default_generate_candidate"); - ICHECK(candgen_func_) << "Default candidate generation function is not found."; - } - - // TODO(@sunggg): Currently, only supports basic arguments. - IRModule TuneIRMod(IRModule mod, transform::PassContext ctx) { - Trace trace = Downcast(ctx->GetCurrentTrace()); - ctx->PopTrace(); - Choice choice("meta_schedule.tune_relax_irmod_with_tuning_api", {target_, config_, work_dir_}, - "relax.tuning_api.Choice.default_constr_func", {}); - Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); - Array candidates = (*candgen_func_)(Array({knob}), trace); - ICHECK(candidates.size() == 1); - Trace best_trace = candidates[0]; - ctx->PushTrace(best_trace); - return best_trace->out_mod; - } - - // TODO(@sunggg): Currently, only supports basic arguments. - tir::PrimFunc TuneTIR(tir::PrimFunc f, transform::PassContext ctx) { - auto parse_mod_func = runtime::Registry::Get("tvm.meta_schedule.tune.parse_mod"); - ICHECK(parse_mod_func) << "Parse function is not found."; - // TODO(@sunggg): Whenever we tune tir, assume we start a new trace w/o pushing to the trace - // stack. Revisit later when we collect more usecases. - Trace trace = Trace((*parse_mod_func)(f), {}, {}); - - Choice choice("meta_schedule.tune_tir_with_tuning_api", {target_, config_, work_dir_}, - "relax.tuning_api.Choice.default_constr_func", {}); - Knob knob("meta_schedule.tune_irmod", {{"0", choice}}); - Array candidates = (*candgen_func_)(Array({knob}), trace); - ICHECK(candidates.size() == 1); - Trace best_trace = candidates[0]; - auto gvars = best_trace->out_mod->GetGlobalVars(); - ICHECK(gvars.size() == 1); - auto new_func = best_trace->out_mod->functions[gvars[0]]; - ICHECK(new_func->IsInstance()); - return Downcast(new_func); - } - - private: - Target target_; - Array config_; - String work_dir_; - const runtime::PackedFunc* candgen_func_; -}; +namespace transform { -class MetaScheduleAHB { - public: - explicit MetaScheduleAHB(const tvm::meta_schedule::Database& db, Target target) - : db_(db), target_(target) {} - IRModule Apply(IRModule mod) { - IRModule ret_mod_ = IRModule(); - tvm::meta_schedule::ApplyHistoryBest ahb(db_, nullptr, nullptr); - for (auto& p : mod->functions) { - GlobalVar gv = p.first; - BaseFunc func = p.second; - BaseFunc newfunc = func; - if (func->IsInstance()) { - IRModule tir_mod(Map({{gv, func}})); - ObjectRef res = - ahb->Query(gv->name_hint, mod, target_, Array{tir_mod}, nullptr, nullptr); - // replace the tir func only when the schedule is found in tuning database. - if (res.defined()) { - IRModule newmod = Downcast(res); - ICHECK_EQ(newmod->functions.size(), 1); - newfunc = (*newmod->functions.begin()).second; +Pass MetaScheduleApplyDatabase() { + using tvm::meta_schedule::Database; + Target target = Target::Current(false); + Database database = Database::Current().value(); + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + Map result; + for (const auto& iter : mod->functions) { + GlobalVar gv = iter.first; + BaseFunc base_func = iter.second; + if (const auto* prim_func = base_func.as()) { + if (Optional sch = database->QuerySchedule( + IRModule({{gv, GetRef(prim_func)}}), target, gv->name_hint)) { + IRModule new_mod = sch.value()->mod(); + ICHECK_EQ(new_mod->functions.size(), 1); + BaseFunc new_base_func = (*new_mod->functions.begin()).second; + result.Set(gv, new_base_func); + continue; } } - - ret_mod_->Add(gv, newfunc); + result.Set(gv, base_func); } - return ret_mod_; - } - - private: - const tvm::meta_schedule::Database& db_; - Target target_; -}; - -namespace transform { - -Pass MetaScheduleTuneIRMod(Target target, Array config, String work_dir) { - runtime::TypedPackedFunc pass_func = [=](IRModule m, - PassContext ctx) { - return MetaScheduleTuner(target, config, work_dir).TuneIRMod(m, ctx); + return IRModule(result); }; - return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, - /*pass name*/ "MetaScheduleTuneIRModule", - /*required*/ {}, - /*traceable*/ true); -} - -Pass MetaScheduleTuneTIR(Target target, Array config, String work_dir) { - runtime::TypedPackedFunc pass_func = - [=](tir::PrimFunc f, IRModule mod, PassContext ctx) { - return MetaScheduleTuner(target, config, work_dir).TuneTIR(f, ctx); - }; - return tir::transform::CreatePrimFuncPass(/*pass function*/ pass_func, /*opt level*/ 0, - /*pass name*/ "MetaScheduleTuneTIR", - /*required*/ {}, - /*traceable*/ true); -} - -Pass MetaScheduleApplyHistoryBest(const tvm::meta_schedule::Database& database, Target target) { - runtime::TypedPackedFunc pass_func = - [=](IRModule m, PassContext ctx) { return MetaScheduleAHB(database, target).Apply(m); }; - return CreateModulePass(/*pass function*/ pass_func, /*opt level*/ 0, - /*pass name*/ "MetaScheduleApplyHistoryBest", - /*required*/ {}); + return CreateModulePass(pass_func, 0, "MetaScheduleApplyDatabase", {}); } -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneIRMod").set_body_typed(MetaScheduleTuneIRMod); -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleTuneTIR").set_body_typed(MetaScheduleTuneTIR); -TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyHistoryBest") - .set_body_typed(MetaScheduleApplyHistoryBest); +TVM_REGISTER_GLOBAL("relax.transform.MetaScheduleApplyDatabase") + .set_body_typed(MetaScheduleApplyDatabase); } // namespace transform } // namespace relax diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index 0e11696379..177f890d56 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -28,8 +28,19 @@ #include #include "../../../meta_schedule/utils.h" + +namespace tvm { +namespace meta_schedule { + +void JSONFileAppendLine(const String& path, const std::string& line); +std::vector JSONFileReadLines(const String& path, int num_threads, bool allow_missing); + +} // namespace meta_schedule +} // namespace tvm + namespace tvm { namespace relax { + TuningRecord::TuningRecord(Trace trace, Optional> run_secs) { ObjectPtr n = make_object(); n->trace = trace; @@ -94,7 +105,8 @@ inline std::string get_database_key(int workload_idx, Target target) { return std::to_string(workload_idx) + "/" + target->str(); } -/*! \brief The default database implementation, which mimics two database tables with two files. */ +/*! \brief The default database implementation, which mimics two database tables with two files. + */ class JSONDatabaseNode : public DatabaseNode { public: /*! \brief The path to the workload table */ diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index a81c56922d..addf129284 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -26,11 +26,15 @@ namespace ir_builder { namespace ir { void IRModuleFrameNode::ExitWithScope() { - ICHECK_EQ(functions.size(), global_vars.size()); - int n = functions.size(); Map func_map; - for (int i = 0; i < n; ++i) { - func_map.Set(global_vars[i], functions[i]); + CHECK_EQ(functions.size(), global_var_map.size()) + << "All functions must be defined in the IRModule. Got " << global_var_map.size() + << "declared function(s), but only " << functions.size() << "defined function(s)."; + for (const auto& kv : functions) { + const GlobalVar& gv = kv.first; + const BaseFunc& func = kv.second; + CHECK(func.defined()) << "ValueError: function " << gv->name_hint << " is not defined"; + func_map.Set(gv, func); } IRBuilder builder = IRBuilder::Current(); ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc index a8cc452e4f..de8a7a3b09 100644 --- a/src/script/ir_builder/ir/ir.cc +++ b/src/script/ir_builder/ir/ir.cc @@ -20,6 +20,8 @@ #include #include +#include "./utils.h" + namespace tvm { namespace script { namespace ir_builder { @@ -27,12 +29,34 @@ namespace ir { IRModuleFrame IRModule() { ObjectPtr n = make_object(); - n->global_vars.clear(); + n->global_var_map.clear(); n->functions.clear(); return IRModuleFrame(n); } +GlobalVar DeclFunction(const String& func_name) { + IRModuleFrame frame = FindModuleFrame("I.DeclFunction"); + CHECK(!frame->global_var_map.count(func_name)) + << "ValueError: function " << func_name << " already exists"; + GlobalVar gv = GlobalVar(func_name); + frame->global_var_map.Set(func_name, gv); + return gv; +} + +void DefFunction(const String& func_name, const BaseFunc& func) { + IRModuleFrame frame = FindModuleFrame("I.DefFunction"); + auto it = frame->global_var_map.find(func_name); + CHECK(it != frame->global_var_map.end()) + << "ValueError: function " << func_name << " does not exist, please declare it first."; + const GlobalVar& gv = (*it).second; + CHECK(frame->functions.find(gv) == frame->functions.end()) + << "ValueError: function " << func_name << " has already been defined."; + frame->functions.Set(gv, func); +} + TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction); +TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction); } // namespace ir } // namespace ir_builder diff --git a/src/script/ir_builder/ir/utils.h b/src/script/ir_builder/ir/utils.h new file mode 100644 index 0000000000..58d5e53f70 --- /dev/null +++ b/src/script/ir_builder/ir/utils.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ + +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace ir { + +inline IRModuleFrame FindModuleFrame(const String& method) { + IRBuilder builder = IRBuilder::Current(); + if (Optional frame = builder->FindFrame()) { + const Optional& last_module_frame = builder->GetLastFrame(); + if (last_module_frame.defined() && last_module_frame.value() == frame) { + return frame.value(); + } + } else { + LOG(FATAL) << "ValueError: IRModule frame not find. Please ensure '" << method + << "' is called under I.ir_module()"; + } + LOG(FATAL) << "ValueError: '" << method << "' must be called immediately under I.ir_module()"; + throw; +} + +} // namespace ir +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_IR_UTILS_H_ diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index f319c21191..8a7c2ff538 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -40,6 +40,7 @@ void SeqExprFrameNode::ExitWithScope() { } void FunctionFrameNode::ExitWithScope() { + using ir::IRModuleFrame; using tvm::relax::Expr; SeqExprFrameNode::ExitWithScope(); IRBuilder builder = IRBuilder::Current(); diff --git a/src/script/ir_builder/relax/ir.cc b/src/script/ir_builder/relax/ir.cc new file mode 100644 index 0000000000..f102e49650 --- /dev/null +++ b/src/script/ir_builder/relax/ir.cc @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include "./utils.h" + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +///////////////////////////////// Vars ////////////////////////////////// + +using tvm::script::ir_builder::details::Namer; + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::VarNode; + const VarNode* var = node.as(); + relay::IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +TVM_STATIC_IR_FUNCTOR(Namer, vtable) + .set_dispatch([](const ObjectRef& node, String name) -> void { + using tvm::relax::DataflowVarNode; + const DataflowVarNode* var = node.as(); + relay::IdNode* vid = const_cast(var->vid.get()); + vid->name_hint = name; + }); + +////////////////////////////// Tensor Type ////////////////////////////// + +TensorType::TensorType(tvm::relax::DynTensorType type, Optional shape) { + auto n = make_object(); + n->type = std::move(type); + n->shape = std::move(shape); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(TensorTypeNode); + +TensorType Tensor(Optional> shape, DataType dtype, int ndim) { + using namespace tvm::relax; + if (shape.defined() && ndim >= 0) { + CHECK_EQ(shape.value().size(), ndim) + << "The dimension of the given shape is mismatched with the given `ndim`"; + } else if (shape.defined()) { + ndim = shape.value().size(); + } + Optional shape_expr = NullOpt; + if (shape.defined()) { + shape_expr = ShapeExpr(shape.value()); + } + return TensorType(DynTensorType(ndim, dtype), shape_expr); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Tensor").set_body_typed(Tensor); + +/////////////////////////////// Function //////////////////////////////// + +FunctionFrame Function() { + ObjectPtr n = make_object(); + n->block_builder = tvm::relax::BlockBuilder::Create(/*mod=*/NullOpt); + return FunctionFrame(n); +} + +tvm::relax::Var Arg(const String& name, const Type& type, const tvm::relax::ShapeExpr& shape) { + FunctionFrame frame = FindFunctionFrame("R.Arg"); + tvm::relax::Var var(name, shape, type); + frame->params.push_back(var); + return var; +} + +void FuncName(const String& name) { + FunctionFrame frame = FindFunctionFrame("R.func_name"); + if (frame->name.defined()) { + LOG(FATAL) << "ValueError: Duplicate function name, previous one is: \"" << frame->name.value() + << "\""; + } + frame->name = name; +} + +void FuncAttrs(Map attrs) { + FunctionFrame frame = FindFunctionFrame("R.func_attr"); + if (!frame->attrs.empty()) { + LOG(FATAL) << "ValueError: Duplicate function attrs, previous one is:\n" << frame->attrs; + } + frame->attrs = attrs; +} + +void FuncRetType(tvm::Type ret_type) { + FunctionFrame frame = FindFunctionFrame("R.ret_type"); + if (frame->ret_type.defined()) { + LOG(FATAL) << "ValueError: Duplicate function return type, previous one is:\n " + << frame->ret_type.value(); + } + frame->ret_type = ret_type; +} + +void FuncRetValue(const tvm::relax::Expr& value) { + // Step 1. The current Relax TVMScript syntax only allows function return appearing at the end of + // a function body. Therefore if there is any unended block frame when dealing with function + // return, we should end the block frame. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + block_frame.value()->ExitWithScope(); + ICHECK(!IRBuilder::Current()->FindFrame()) + << "All block frame are supposed to be popped out already"; + } + // Step 2. Add the output value to the function frame. + FunctionFrame frame = FindFunctionFrame("return"); + CHECK(!frame->output.defined()) + << "ValueError: Relax functions don't support multiple return statement. Please make sure " + "the return statement appears at the end of function."; + frame->output = value; +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Function").set_body_typed(Function); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Arg").set_body_typed(Arg); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncName").set_body_typed(FuncName); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncAttrs").set_body_typed(FuncAttrs); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetType").set_body_typed(FuncRetType); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.FuncRetValue").set_body_typed(FuncRetValue); + +///////////////////////////// BindingBlock ////////////////////////////// + +BlockFrame Dataflow() { + ObjectPtr n = make_object(); + n->is_dataflow = true; + n->block_ended = false; + return BlockFrame(n); +} + +BlockFrame BindingBlock() { + ObjectPtr n = make_object(); + n->is_dataflow = false; + n->block_ended = false; + return BlockFrame(n); +} + +void DataflowBlockOutput(const Array& vars) { + // Step 1. Check that we're in a Dataflow block that is not ended. + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + CHECK(block_frame.defined() && block_frame.value()->is_dataflow) + << "ValueError: `R.output` should appear inside a dataflow block. However, the current " + "innermost block is not a dataflow block."; + CHECK(!block_frame.value()->block_ended) + << "ValueError: It is not allowed for a dataflow block to have multiple output operation."; + + // Step 2. Mark the block frame ended of construction, so that any followup binding after this + // mark in the dataflow block will lead to an error. + block_frame.value()->block_ended = true; + + // Step 3. All the output variables must be global variables and must be emitted by this dataflow + // block. + Array emitted_vars = block_frame.value()->emitted_vars; + for (const tvm::relax::Var& var : vars) { + CHECK(!var->IsInstance()) + << "ValueError: The output variables of a dataflow block must be all global variables."; + CHECK(std::find(emitted_vars.begin(), emitted_vars.end(), var) != emitted_vars.end()) + << "ValueError: An output variable is not emitted by this dataflow block. Please make sure " + "all dataflow block output variables are emitted exactly by this block."; + } + + // Step 4. All normal variables emitted by this dataflow blocks should be output variables. + for (const tvm::relax::Var& emitted_var : emitted_vars) { + if (!emitted_var->IsInstance()) { + CHECK(std::find(vars.begin(), vars.end(), emitted_var) != vars.end()) + << "ValueError: An non-dataflow variable of this dataflow block is not an output " + "variable. Please make sure all non-dataflow variables emitted by this block are all " + "contained in the output variable list."; + } + } +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Dataflow").set_body_typed(Dataflow); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.BindingBlock").set_body_typed(BindingBlock); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.DataflowBlockOutput") + .set_body_typed(DataflowBlockOutput); + +/////////////////////////////// Bindings /////////////////////////////// + +tvm::relax::Var Emit(const tvm::relax::Expr& expr, bool is_dataflow_var) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::Var var{nullptr}; + if (block_frame->is_dataflow && !is_dataflow_var) { + var = block_builder->EmitOutput(expr); + } else { + var = block_builder->Emit(expr); + } + block_frame->emitted_vars.push_back(var); + return var; +} + +Optional EmitMatchShape(const tvm::relax::Expr& value, // + const Array& pattern, // + bool emit_var, // + bool is_dataflow_var) { + BlockFrame block_frame = CheckBlockFrameExistAndUnended(); + tvm::relax::BlockBuilder block_builder = GetBlockBuilder(); + + // If we don't intend to emit a variable, just emit the binding and return. + if (!emit_var) { + tvm::relax::MatchShape match_shape(value, pattern, tvm::relax::Var{nullptr}); + block_builder->EmitMatchShape(match_shape); + return NullOpt; + } + + // TODO(tvm-team): Enhance the API of EmitMatchShape in BlockBuilder and then update the following + // code snippet + tvm::relax::Var var{nullptr}; + tvm::relax::Id vid(is_dataflow_var ? "lv" : "gv"); + + if (is_dataflow_var) { + var = tvm::relax::DataflowVar(vid, NullOpt, NullOpt); + } else { + var = tvm::relax::Var(vid, NullOpt, NullOpt); + } + + if (value->checked_type().as()) { + UpdateType(var, tvm::relax::ShapeType()); + } else if (const tvm::relax::DynTensorTypeNode* tty = + value->checked_type().as()) { + tvm::relax::ShapeExpr shape = tvm::relax::ShapeExpr(pattern); + UpdateShape(var, shape); + DataType dtype = tty->dtype; + UpdateType(var, tvm::relax::DynTensorType(pattern.size(), dtype)); + } else { + LOG(FATAL) << "The value passed to EmitMatchShape must be of DynTensorType or ShapeType."; + } + + block_frame->emitted_vars.push_back(var); + return block_builder->EmitMatchShape(tvm::relax::MatchShape(value, pattern, var)); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Emit").set_body_typed(Emit); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.EmitMatchShape").set_body_typed(EmitMatchShape); + +///////////////////////////// Type Deduce ////////////////////////////// + +void AnnotateTypeShape(const tvm::relax::Var& var, const Type& type, + const Optional& shape) { + using tvm::relax::IsBaseOf; + if (!var->checked_type_.defined()) { + var->checked_type_ = type; + } else { + const Type& var_type = var->checked_type(); + if (IsBaseOf(type, var_type)) { + // The var type is equal or more detailed than annotated one, do nothing. + } else if (IsBaseOf(var_type, type)) { + LOG(WARNING) << "The inferred type of var " << var->name_hint() + << " by the block builder is more refined than the annotated one. The system " + "will refine it automatically."; + var->checked_type_ = type; + } else { + LOG(FATAL) << "TypeError: The annotated type and value type are not compatible. " + << "The Type is expected to be " << var_type << " but got annotation: " << type; + } + } + + if (!var->shape_.defined()) { + var->shape_ = shape; + } else if (shape.defined()) { + const tvm::relax::BlockBuilder& block_builder = GetBlockBuilder(); + tvm::relax::Expr var_shape = Downcast(var->shape_.value()); + CHECK(block_builder->CanProveShapeEqual(var_shape, shape.value())) + << " The shape of var " << var->name_hint() << " is expected to be " << var_shape + << " but got annotation: " << shape.value(); + } +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.AnnotateTypeShape").set_body_typed(AnnotateTypeShape); + +///////////////////////////// If Then Else ///////////////////////////// + +IfFrame If(tvm::relax::Expr condition) { + ObjectPtr n = make_object(); + n->condition = condition; + n->then_expr = NullOpt; + n->else_expr = NullOpt; + return IfFrame(n); +} + +ThenFrame Then() { + ObjectPtr n = make_object(); + return ThenFrame(n); +} + +ElseFrame Else() { + ObjectPtr n = make_object(); + return ElseFrame(n); +} + +TVM_REGISTER_GLOBAL("script.ir_builder.relax.If").set_body_typed(If); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Then").set_body_typed(Then); +TVM_REGISTER_GLOBAL("script.ir_builder.relax.Else").set_body_typed(Else); + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm diff --git a/src/script/ir_builder/relax/utils.h b/src/script/ir_builder/relax/utils.h new file mode 100644 index 0000000000..e55957cdbf --- /dev/null +++ b/src/script/ir_builder/relax/utils.h @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ +#define TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ + +#include +#include + +namespace tvm { +namespace script { +namespace ir_builder { +namespace relax { + +inline FunctionFrame FindFunctionFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->FindFrame()) { + return frame.value(); + } + LOG(FATAL) << "ValueError: Function frame not find. Please ensure '" << method + << "' is called under R.function()"; + throw; +} + +inline IfFrame FindIfFrame(const String& method) { + if (Optional frame = IRBuilder::Current()->GetLastFrame()) { + return frame.value(); + } else { + LOG(FATAL) << "ValueError: IfThenElse frame not find. Please ensure '" << method + << "' is called under R.if_()"; + } + throw; +} + +inline tvm::relax::BlockBuilder GetBlockBuilder() { + Optional frame = IRBuilder::Current()->FindFrame(); + CHECK(frame.defined()) << "ValueError: Relax Function frame not find. Please ensure " + "assignment is called under R.function()"; + return frame.value()->block_builder; +} + +inline BlockFrame CheckBlockFrameExistAndUnended() { + // - If we're emitting a non-dataflow binding in the function (that is to say, the binding is not + // wrapped by `with R.dataflow()`), it is possible that there is no existing BlockFrame. In this + // case, we will create a BlockFrame and "enter its 'with' scope" first. + // - Otherwise, there is already an existing BlockFrame. We check if the block is "ended" - if a + // block is ended, it is not allowed to emit new bindings into this block, and we should throw + // exceptions. + + Optional block_frame = IRBuilder::Current()->GetLastFrame(); + if (block_frame.defined()) { + CHECK(!block_frame.value()->block_ended) + << "ValueError: New binding is not allowed after dataflow block output."; + return block_frame.value(); + } + + BlockFrame new_block_frame = BindingBlock(); + new_block_frame->EnterWithScope(); + return new_block_frame; +} + +inline tvm::relax::SeqExpr GetSeqExprForBranch(const SeqExprFrame& frame, String* var_name) { + // Step 0. Check frame type + std::string method; + if (frame->IsInstance()) { + method = "R.Then"; + } else if (frame->IsInstance()) { + method = "R.Else"; + } else { + ICHECK(false) << "TypeError: Unsupported frame type: " << frame->GetTypeKey(); + } + + // Step 1. Check non-empty block and last binding is non-dataflow + CHECK(!frame->binding_blocks.empty()) + << "Empty body is not allowed for '" << method << "' statements."; + const tvm::relax::BindingBlock& last_block = frame->binding_blocks.back(); + CHECK(!last_block->bindings.empty()) << "Blocks are expected to be non-empty."; + + // Step 2. Collect body from the last binding. + tvm::relax::Expr body; + const tvm::relax::Binding& last_binding = last_block->bindings.back(); + if (const auto* var_binding = last_binding.as()) { + CHECK(!var_binding->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = var_binding->var->name_hint(); + } else if (const auto* match_shape = last_binding.as()) { + CHECK(match_shape->var.defined() && + !match_shape->var->IsInstance()) + << "A non-dataflow var is expected in the last binding of '" << method << "'."; + body = var_binding->value; + *var_name = match_shape->var->name_hint(); + } else { + ICHECK(false) << "TypeError: Unsupported binding type: " << last_binding->GetTypeKey(); + } + + // Step 3. Re-collect binding blocks to remove the last binding. + Array new_blocks(frame->binding_blocks.begin(), + frame->binding_blocks.end() - 1); + Array last_block_bindings(last_block->bindings.begin(), + last_block->bindings.end() - 1); + new_blocks.push_back(tvm::relax::BindingBlock(last_block_bindings)); + + return tvm::relax::SeqExpr(new_blocks, body); +} + +} // namespace relax +} // namespace ir_builder +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_IR_BUILDER_RELAX_UTILS_H_ diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 1e63201a40..57ba54c253 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include @@ -41,9 +42,17 @@ void PrimFuncFrameNode::ExitWithScope() { ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; builder->result = func; } else if (Optional opt_frame = builder->FindFrame()) { - ir::IRModuleFrame frame = opt_frame.value(); - frame->global_vars.push_back(GlobalVar(name.value_or(""))); - frame->functions.push_back(func); + CHECK(name.defined()) << "ValueError: The function name must be defined before exiting the " + "function scope, if it's defined in a Module"; + const ir::IRModuleFrame& frame = opt_frame.value(); + const String& func_name = name.value_or(""); + if (!frame->global_var_map.count(func_name)) { + // Case. First time visiting the function. + ir::DeclFunction(func_name); + } + // Define the function. + // Note we do checks to disallow redefinition of functions inside the `DefFunction`. + ir::DefFunction(func_name, func); } else { LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; } diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index e3dc3062f5..2837263993 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -551,8 +551,8 @@ PrimFunc GenerateAndCompletePrimFunc(const Array& arg_list, PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, - std::optional index_dtype_override) - const Optional>& tir_var_list) { + const Optional>& tir_var_list, + std::optional index_dtype_override) { // Infomations used in CreatePrimFunc and its sub-functions. CreateFuncInfo info(arg_list); // Root body stmts. @@ -580,19 +580,20 @@ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, } PrimFunc CreatePrimFunc(const Array& arg_list, + const Optional> tir_var_list, std::optional index_dtype_override) { - return CreatePrimFuncWithConstants(arg_list, {}, index_dtype_override); + return CreatePrimFuncWithConstants(arg_list, {}, tir_var_list, index_dtype_override); } TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body([](TVMArgs args, TVMRetValue* ret) { Array arg_list = args[0]; + Optional> tir_var_list = args[1]; std::optional index_dtype_override{std::nullopt}; // Add conversion to make std::optional compatible with FFI. - ICHECK_EQ(args.size(), 2); - if (args[1].type_code() != kTVMNullptr) { - index_dtype_override = args[1].operator DataType(); + if (args[2].type_code() != kTVMNullptr) { + index_dtype_override = args[2].operator DataType(); } - *ret = CreatePrimFunc(arg_list, index_dtype_override); + *ret = CreatePrimFunc(arg_list, tir_var_list, index_dtype_override); }); } // namespace tir diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 00442bb920..483c59f324 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -40,7 +40,7 @@ PrimFunc CreatePrimFunc(const Array& arg_list, */ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants, - std::optional index_dtype_override = std::nullopt); + const Optional>& tir_var_list); } // namespace tir } // namespace tvm diff --git a/tests/python/relax/test_autotir_integration.py b/tests/python/relax/test_autotir_integration.py index 1635598a63..7420b76e9a 100644 --- a/tests/python/relax/test_autotir_integration.py +++ b/tests/python/relax/test_autotir_integration.py @@ -16,21 +16,20 @@ # under the License. from __future__ import annotations -import numpy as np -import pytest import tempfile import time + +import numpy as np +import pytest import tvm import tvm.testing - from tvm import meta_schedule as ms -from tvm import relax -from tvm import transform +from tvm import relax, transform from tvm.ir.module import IRModule -from tvm.script import relax as R, tir as T +from tvm.script import relax as R +from tvm.script import tir as T from tvm.target.target import Target - # Test case with dynamic shape. # Tuning with dynamic shape is not supported yet. """ @@ -132,18 +131,22 @@ def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tens database = ms.database.MemoryDatabase() with tempfile.TemporaryDirectory() as work_dir: - relax_ex = ms.tune_relax( + db = ms.relax_integration.tune_relax( mod=mod, target=target, - config=ms.TuneConfig( - strategy="evolutionary", - num_trials_per_iter=2, - max_trials_per_task=4, - max_trials_global=4, - ), + params=None, + num_trials_per_iter=2, + max_trials_per_task=4, + max_trials_global=4, work_dir=work_dir, database=database, ) + relax_ex = ms.relax_integration.compile_relax( + db, + mod=mod, + target=target, + params=None, + ) if dev == "cpu": with transform.PassContext(opt_level=3): @@ -175,7 +178,7 @@ def test_autotir_gpu(): test_autotir("cuda") -def test_meta_schedule_extract_task_from_relax(): +def test_meta_schedule_extract_tasks(): @tvm.script.ir_module class Module: @T.prim_func @@ -220,7 +223,7 @@ def main(x: Tensor((128, 128), "float32")) -> Tensor(_, "float32"): relax.output(gv) return gv - tasks = ms.relax_integration.extract_task_from_relax(Module, Target("llvm --num-cores=16")) + tasks = ms.relax_integration.extract_tasks(Module, Target("llvm --num-cores=16")) expected_weights = {"add1": 3, "add2": 1, "multiply1": 2} assert len(tasks) == len(expected_weights) for task in tasks: diff --git a/tests/python/relax/test_relay_translator.py b/tests/python/relax/test_relay_translator.py index 8a1dcf1b49..d071ffae63 100644 --- a/tests/python/relax/test_relay_translator.py +++ b/tests/python/relax/test_relay_translator.py @@ -15,19 +15,20 @@ # specific language governing permissions and limitations # under the License. +import tempfile + +import numpy as np +import pytest import tvm -from tvm.ir.base import assert_structural_equal -from tvm.runtime import vm import tvm.testing -from tvm.relay import testing +from tvm import meta_schedule as ms from tvm import relax, relay +from tvm.ir.base import assert_structural_equal from tvm.relax.testing import relay_translator -from tvm import meta_schedule as ms -from tvm.target import Target -import numpy as np -import pytest -import tempfile +from tvm.relay import testing +from tvm.runtime import vm from tvm.script import tir as T +from tvm.target import Target def get_resnet(batch_size, dtype, layout, image_shape): @@ -44,41 +45,47 @@ def get_resnet(batch_size, dtype, layout, image_shape): def relay_build_and_run(mod, target, dev, params, data): with tempfile.TemporaryDirectory() as work_dir: - ex = ms.tune_relay( + db = ms.relay_integration.tune_relay( mod=mod, params=params, target=target, - config=ms.EvolutionarySearchConfig( - num_trials_per_iter=32, - max_trials_per_task=3, - max_trials_global=300, - ), - task_scheduler="round_robin", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, + task_scheduler="round-robin", work_dir=work_dir, ) - + ex = ms.relay_integration.compile_relay( + db, + mod=mod, + target=target, + params=params, + ) rt_mod = tvm.contrib.graph_executor.GraphModule(ex["default"](dev)) rt_mod.set_input("data", data) rt_mod.run() - out = rt_mod.get_output(0).asnumpy() + out = rt_mod.get_output(0).numpy() return ex, rt_mod, out def relax_build_and_run(mod, target, dev, params, data): mod = relax.transform.BindParams("main", params)(mod) with tempfile.TemporaryDirectory() as work_dir: - ex = ms.tune_relax( + db = ms.relax_integration.tune_relax( mod=mod, target=target, - config=ms.TuneConfig( - strategy="evolutionary", - task_scheduler="round_robin", - num_trials_per_iter=32, - max_trials_per_task=3, - max_trials_global=300, - ), + task_scheduler="round-robin", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=1024, work_dir=work_dir, ) + ex = ms.relax_integration.compile_relax( + db, + mod=mod, + target=target, + params=params, + ) vm = relax.VirtualMachine(ex, dev) res = vm["main"](data) out = res.numpy() @@ -91,12 +98,9 @@ def verify_e2e_translation(target_str, layout, batch_size, image_shape): relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) input_shape = (1, *image_shape) data = tvm.nd.array(np.random.rand(*input_shape).astype(np.float32), dev) - relax_mod = relay_translator.from_relay(relay_mod["main"], target, params) - - relay_ex, relay_rt_mod, relay_out = relay_build_and_run(relay_mod, target, dev, params, data) - relax_ex, relax_rt_mod, relax_out = relax_build_and_run(relax_mod, target, dev, params, data) - + _, _, relay_out = relay_build_and_run(relay_mod, target, dev, params, data) + _, _, relax_out = relax_build_and_run(relax_mod, target, dev, params, data) tvm.testing.assert_allclose(relay_out, relax_out, atol=1e-5, rtol=1e-5) @@ -120,7 +124,6 @@ def test_verify_e2e_translation_gpu(layout, batch_size, image_shape): def verify_extracted_tasks(target_str, layout, batch_size, image_shape): target = Target(target_str) relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape) - relax_mod = relay_translator.from_relay( relay_mod["main"], target, @@ -130,8 +133,7 @@ def verify_extracted_tasks(target_str, layout, batch_size, image_shape): "relay.FuseOps.max_depth": 1, # Disable relay fusion }, ) - - relay_tasks = ms.extract_task_from_relay( + relay_tasks = ms.relay_integration.extract_tasks( relay_mod, target=target, params=params, @@ -140,8 +142,11 @@ def verify_extracted_tasks(target_str, layout, batch_size, image_shape): "relay.FuseOps.max_depth": 1, # Disable relay fusion }, ) - - relax_tasks = ms.extract_task_from_relax(relax_mod, target=target, params=params) + relax_tasks = ms.relax_integration.extract_tasks( + relax_mod, + target=target, + params=params, + ) # TODO (yongwww, yuchen): tophub guides relay passes, which causes inconsistent tasks # assert len(relay_tasks) == len(relax_tasks) # TODO: Can we compare extracted tasks as well? @@ -261,8 +266,7 @@ def tir_matmul( a = relay.var("a", shape=shape) relay_mod = tvm.IRModule.from_expr(relay.Function([a], a * a)) - - relay_vm, relax_vm, relax_mod = translate_and_build_vms( + _, _, relax_mod = translate_and_build_vms( relay_mod, translate_op_with_tir={"multiply": tir_matmul} ) assert_structural_equal(relax_mod["multiply"], tir_matmul) diff --git a/tests/python/relax/test_transform_lower_with_op_strategy.py b/tests/python/relax/test_transform_lower_with_op_strategy.py index 8e0e9d2e0f..148df11014 100644 --- a/tests/python/relax/test_transform_lower_with_op_strategy.py +++ b/tests/python/relax/test_transform_lower_with_op_strategy.py @@ -16,17 +16,19 @@ # under the License. from __future__ import annotations -import pytest -import numpy as np + import tempfile + +import numpy as np +import pytest import tvm import tvm.script import tvm.testing +from tvm import meta_schedule as ms from tvm import relax -from tvm.target import Target from tvm.relax.testing import transform from tvm.script import relax as R -from tvm import meta_schedule as ms +from tvm.target import Target @tvm.script.ir_module @@ -43,18 +45,16 @@ def main( def build_and_run(mod, target, dev, np_inputs): inputs = [tvm.nd.array(np_input, dev) for np_input in np_inputs] with tempfile.TemporaryDirectory() as work_dir: - ex = ms.tune_relax( + db = ms.relax_integration.tune_relax( mod=mod, + params=None, target=target, - config=ms.TuneConfig( - strategy="evolutionary", - task_scheduler="round_robin", - num_trials_per_iter=20, - max_trials_per_task=20, - max_trials_global=20, - ), work_dir=work_dir, + num_trials_per_iter=20, + max_trials_global=20, + task_scheduler="round-robin", ) + ex = ms.relax_integration.compile_relax(db, mod, target, params=None) vm = relax.VirtualMachine(ex, dev) vm["main"](*inputs) @@ -83,4 +83,5 @@ def test_lowering_gpu(target_str="nvidia/nvidia-t4"): if __name__ == "__main__": - pytest.main([__file__]) + test_lowering_cpu() + test_lowering_gpu() diff --git a/tests/python/relax/test_transform_meta_schedule_tuning.py b/tests/python/relax/test_transform_meta_schedule_tuning.py index 0806c403ae..dc34b7e45f 100644 --- a/tests/python/relax/test_transform_meta_schedule_tuning.py +++ b/tests/python/relax/test_transform_meta_schedule_tuning.py @@ -16,19 +16,23 @@ # under the License. from __future__ import annotations -import pytest + import tempfile + +import pytest import tvm +import tvm.meta_schedule as ms +from tvm import relax from tvm.ir import transform -from tvm.ir.transform import PassContext from tvm.ir.module import IRModule -from tvm.script import tir as T, relax as R -from tvm import relax -import tvm.meta_schedule as ms +from tvm.ir.transform import PassContext from tvm.relax.transform.tuning_api import Trace +from tvm.script import relax as R +from tvm.script import tir as T -def test_metaschedule_tuning(): +@pytest.mark.xfail(reason="TuningAPI is broken after rebase") +def test_meta_schedule_tuning(): @tvm.script.ir_module class InputModule: @T.prim_func @@ -70,7 +74,7 @@ def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tens mod = InputModule assert isinstance(mod, IRModule) - target_str = "llvm --num-cores=16" + target = tvm.target.Target("llvm --num-cores=16") config = ms.TuneConfig( strategy="evolutionary", num_trials_per_iter=2, @@ -80,16 +84,14 @@ def main(x: Tensor((32, 32), "float32"), w: Tensor((32, 32), "float32")) -> Tens with tempfile.TemporaryDirectory() as work_dir: seq = transform.Sequential( - [relax.transform.MetaScheduleTuneIRMod(tvm.target.Target(target_str), config, work_dir)] + [relax.transform.MetaScheduleTuneIRMod(target, config, work_dir)] ) with transform.PassContext(trace=Trace(mod), opt_level=0): _ = seq(mod) assert PassContext.current().get_trace_stack_size() == 1 assert PassContext.current().get_current_trace().size == 1 - seq = transform.Sequential( - [relax.transform.MetaScheduleTuneTIR(tvm.target.Target(target_str), config, work_dir)] - ) + seq = transform.Sequential([relax.transform.MetaScheduleTuneTIR(target, config, work_dir)]) with transform.PassContext(trace=Trace(mod), opt_level=0): _ = seq(mod) assert PassContext.current().get_trace_stack_size() == 1 diff --git a/tests/python/relax/test_tvmscript_ir_builder.py b/tests/python/relax/test_tvmscript_ir_builder.py new file mode 100644 index 0000000000..a80f71d62a --- /dev/null +++ b/tests/python/relax/test_tvmscript_ir_builder.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax, tir +from tvm.script.ir_builder import relax as R +from tvm.script.ir_builder.base import IRBuilder + + +def test_function_simple(): + """ + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + out = R.call_tir("extern_func", x, (128, 128), dtype="float32") + return out + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + R.func_attr({"Primitive": 1}) + x = R.arg("x", R.tensor((128, 128), "float32")) + R.func_ret_type(R.tensor(dtype="float32", ndim=2)) + out = R.emit( + R.call_tir("extern_func", x, (128, 128), dtype="float32"), is_dataflow_var=False + ) + IRBuilder.name("out", out) + R.func_ret_value(out) + func = ir_builder.get() + # create with BlockBuilder + x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + bb.emit_func_output(out) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + # check names + assert func.attrs["global_symbol"] == "foo" + assert func.params[0].name_hint == "x" + assert func.body.body.name_hint == "out" + + +def test_match_shape(): + """ + @R.function + def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + m = T.var("int64") + n = T.var("int64") + R.match_shape(x, (m,)) + y1 = R.match_shape(x, (n,)) + return (m, n * 2) + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", R.tensor(ndim=-1, dtype="float32")) + y = R.arg("y", R.tensor(ndim=-1, dtype="float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + R.emit_match_shape(x, (m,), emit_var=False, is_dataflow_var=False) + y1 = R.emit_match_shape(y, (n,), emit_var=True, is_dataflow_var=False) + IRBuilder.name("y1", y1) + R.func_ret_value(relax.ShapeExpr([m, n * 2])) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", type_annotation=relax.DynTensorType(-1, "float32")) + y = relax.Var("y", type_annotation=relax.DynTensorType(-1, "float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + bb.match_shape_binding(relax.MatchShape(x, (m,), var=None)) + y1 = bb.match_shape(y, (n,)) + bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + mod = bb.get() + + tvm.ir.assert_structural_equal(func, mod["foo"]) + + +def test_dataflow_block(): + """ + @R.function + def foo(x: Tensor((128, 128), "float32")) -> Tensor(None, "float32", ndim = 2): + # block 0 + with R.dataflow(): + lv0 = R.call_tir("extern_func", (x,), (128, 128), dtype="float32") + gv: Tensor((128, 128), "float32") = lv0 + R.output(gv) + return gv + """ + # create with Script IRBuilder + with IRBuilder() as ir_builder: + with R.function(): + R.func_name("foo") + x = R.arg("x", R.tensor((128, 128), "float32")) + with R.dataflow(): + lv0 = R.emit( + R.call_tir("extern_func", x, (128, 128), dtype="float32"), is_dataflow_var=True + ) + IRBuilder.name("lv0", lv0) + gv = R.emit(lv0, is_dataflow_var=False) + IRBuilder.name("gv", gv) + R.output(gv) + R.func_ret_value(gv) + func = ir_builder.get() + + # create with BlockBuilder + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + tvm.ir.assert_structural_equal(func, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py new file mode 100644 index 0000000000..8586dc7462 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser.py @@ -0,0 +1,629 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Union + +import pytest +import tvm +import tvm.testing +from tvm import IRModule, relax, tir +from tvm.script._parser import ir as I +from tvm.script._parser import relax as R +from tvm.script._parser import tir as T + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Union[relax.Function, IRModule], +): + # TODO(siyuan): add round-trip tests + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_simple_func(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + R.func_attr({"Primitive": 1}) + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + return gv0 + + x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,), attrs={"Primitive": 1}): + out = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +def test_error_report(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = gv1 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + return gv0 + + +def test_simple_module(): + @I.ir_module + class TestModule: + @T.prim_func + def tir_func(x: T.Buffer((128, 128), "float32"), y: T.Buffer((128, 128), "float32")): + T.func_attr({"global_symbol": "tir_func", "tir.noalias": True}) + for i, j in T.grid(128, 128): + with T.block(): + vi, vj = T.axis.remap("SS", [i, j]) + y[vi, vj] = x[vi, vj] + 1.0 + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + # TODO(Siyuan): Need to change to `TestModule.tir_func` + gv0 = R.call_tir(tir_func, x, (128, 128), dtype="float32") + return gv0 + + x = relax.Var("x", [128, 128], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func") + bb.emit_func_output(out) + + _check(TestModule, bb.get()) + + +def test_relax_tensor_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor(None, "float32", ndim=2): + y = R.add(x, x) + z = R.multiply(x, y) + return z + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_relax_base_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + alloc = R.builtin.alloc_tensor((4, 4), runtime_device_index=0, dtype="float32") + shape = R.shape_of(alloc) + return shape + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) + shape = bb.emit(relax.op.shape_of(alloc)) + bb.emit_func_output(shape) + + _check(foo, bb.get()["foo"]) + + +def test_symbolic_shape(): + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64", "m") + n = T.var("int64", "n") + gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + return gv0 + + @R.function + def bar(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64") + n = T.var("int64") + gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + return gv0 + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def mismatch_dtype(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(None, "float32", ndim=2): + m = T.var("int64") + n = T.var("int32") # The shape dtype should be int64 + gv0 = R.call_tir("extern_func", x, (m, n), dtype="float32") + return gv0 + + def _expected(name: str): + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", [m, n], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function(name, (x,)): + out = bb.emit(relax.call_tir("extern_func", x, (m, n), dtype="float32")) + bb.emit_func_output(out) + return bb.get()[name] + + _check(foo, _expected("foo")) + _check(bar, _expected("bar")) + + +def test_shadowing(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + y = R.add(x, x) + z = R.multiply(x, y) + y = R.add(x, y) + y = z + y = R.multiply(y, x) + z = y + return z + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + y = bb.emit(relax.op.add(x, x)) + z = bb.emit(relax.op.multiply(x, y)) + y = bb.emit(relax.op.add(x, y)) + y = bb.emit(z) + y = bb.emit(relax.op.multiply(y, x)) + z = bb.emit(y) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_match_shape(): + @R.function + def foo(x: R.Tensor(None, "float32"), y: R.Tensor(None, "float32")): + m = T.var("int64") + n = T.var("int64") + R.match_shape(x, (m,)) + y1 = R.match_shape(y, (n,)) + return (m, n * 2) + + x = relax.Var("x", type_annotation=relax.DynTensorType(-1, "float32")) + y = relax.Var("y", type_annotation=relax.DynTensorType(-1, "float32")) + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + bb = relax.BlockBuilder() + with bb.function("foo", (x, y)): + bb.match_shape_binding(relax.MatchShape(x, (m,), var=None)) + y1 = bb.match_shape(y, (n,)) + bb.emit_func_output(relax.ShapeExpr([m, n * 2])) + _check(foo, bb.get()["foo"]) + + +def test_tuple_return(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + gv0 = R.call_tir("extern_func_0", x, (4, 4), dtype="float32") + gv1 = R.call_tir("extern_func_1", x, (4, 4), dtype="float32") + return (gv0, gv1) + + x = relax.Var("x", [4, 4], relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func_0", x, (4, 4), dtype="float32")) + gv1 = bb.emit(relax.call_tir("extern_func_1", x, (4, 4), dtype="float32")) + bb.emit_func_output(relax.Tuple((gv0, gv1))) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + lv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + lv1 = R.call_tir("extern_func", lv0, (128, 128), dtype="float32") + gv = lv1 + R.output(gv) + return gv + + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + lv1 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) + gv = bb.emit_output(lv1) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_block_advanced(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + gv1 = R.call_tir("extern_func", gv0, (128, 128), dtype="float32") + with R.dataflow(): + m = T.var("int64") + n = T.var("int64") + lv0 = R.call_tir("extern_func", gv1, (128, 128), dtype="float32") + lv1 = R.match_shape(lv0, (m, n)) + gv2 = R.call_tir("extern_func", lv0, (128, 128), dtype="float32") + gv2 = R.call_tir("extern_func", gv2, (128, 128), dtype="float32") + gv3 = R.match_shape(gv2, (m, n)) + gv3 = R.match_shape(lv0, (m, n)) + gv4 = gv3 + gv5 = gv2 + R.output(gv5, gv4) + gv6 = R.call_tir("extern_func", gv5, (128, 128), dtype="float32") + gv7 = R.call_tir("extern_func", gv6, (128, 128), dtype="float32") + return gv7 + + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + m = tir.Var("m", dtype="int64") + n = tir.Var("n", dtype="int64") + with bb.function("foo", (x,)): + gv0 = bb.emit(relax.call_tir("extern_func", x, (128, 128), dtype="float32")) + gv1 = bb.emit(relax.call_tir("extern_func", gv0, (128, 128), dtype="float32")) + with bb.dataflow(): + lv0 = bb.emit(relax.call_tir("extern_func", gv1, (128, 128), dtype="float32")) + lv1 = bb.match_shape(lv0, (m, n)) + gv2 = bb.emit(relax.call_tir("extern_func", lv0, (128, 128), dtype="float32")) + gv21 = bb.emit(relax.call_tir("extern_func", gv2, (128, 128), dtype="float32")) + gv3 = bb.match_shape(gv21, (m, n)) + gv31 = bb.match_shape(lv0, (m, n)) + gv32 = bb.emit_output(gv31) + gv22 = bb.emit_output(gv21) + gv4 = bb.emit(relax.call_tir("extern_func", gv22, (128, 128), dtype="float32")) + gv5 = bb.emit(relax.call_tir("extern_func", gv4, (128, 128), dtype="float32")) + bb.emit_func_output(gv5) + + _check(foo, bb.get()["foo"]) + + +def test_dataflow_binding_after_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + R.output(gv) + lv = R.call_tir("extern_func", gv, (128, 128), dtype="float32") + return gv + + +def test_dataflow_output_global_var(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + with R.dataflow(): + gv1 = R.call_tir("extern_func", gv0, (128, 128), dtype="float32") + R.output(gv0, gv1) + return gv1 + + +def test_dataflow_multiple_output(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + with R.dataflow(): + gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + R.output(gv) + R.output(gv) + return gv + + +def test_dataflow_output_outside_dataflow_block(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor(None, "float32", ndim=2): + gv = R.call_tir("extern_func", x, (128, 128), dtype="float32") + R.output(gv) + return gv + + +def test_return_without_binding(): + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + + x = relax.Var("x", (128, 128), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_multiple_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + return x + return x + + +def test_function_without_return(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo(x: R.Tensor((128, 128), "float32")): + gv0 = R.call_tir("extern_func", x, (128, 128), dtype="float32") + + +def test_tensor_type_without_args(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + v = R.call_tir("tir_relu", x, (32, 32), dtype="float32") + return v + + x = relax.Var("x", (32, 32), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + v = bb.emit(relax.call_tir("tir_relu", x, (32, 32), dtype="float32")) + bb.emit_func_output(v) + + _check(foo, bb.get()["foo"]) + + +def test_direct_return(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"): + return x + + x = relax.Var("x", (32, 32), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + bb.emit_func_output(x) + + _check(foo, bb.get()["foo"]) + + +def test_call_packed(): + @R.function + def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: + z = R.call_packed("vm.builtin.copy", x, type_args=R.Tensor((32, 32), "float32")) + return z + + x = relax.Var("x", (32, 32), relax.DynTensorType(2, "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x)): + z = bb.emit( + relax.Call( + relax.ExternFunc("vm.builtin.copy"), + (x,), + None, + type_args=[relax.DynTensorType(2, "float32")], + ) + ) + bb.emit_func_output(z) + + _check(foo, bb.get()["foo"]) + + +def test_annotation(): + @R.function + def foo( + x: R.Tensor((32, "m"), "float32"), + y: R.Tensor(("m"), "float32"), + r: R.Tensor(dtype="int64"), + ) -> R.Object: + m = T.var("int64") + z: R.Tensor((32, m), "float32") = R.multiply(x, y) + w: R.Tensor = R.multiply(z, z) + q: R.Tensor(ndim=2) = R.add(w, w) + t = R.add(w, z) + sh: R.Shape = R.shape_of(t) + o: R.Object = R.call_packed("contrib.tensor_array_stack", x, y, type_args=R.Object) + return o + + m = tir.Var("m", "int64") + x = relax.Var("x", (32, m), relax.DynTensorType(2, "float32")) + y = relax.Var("y", (m,), relax.DynTensorType(1, "float32")) + r = relax.Var("r", None, relax.DynTensorType(-1, "int64")) + bb = relax.BlockBuilder() + with bb.function("foo", (x, y, r)): + z = bb.emit(R.multiply(x, y)) + w = bb.emit(R.multiply(z, z)) + q = bb.emit(R.add(w, w)) + t = bb.emit(R.add(w, z)) + sh = bb.emit(R.shape_of(t)) + o = bb.emit( + relax.Call( + relax.ExternFunc("contrib.tensor_array_stack"), + [x, y], + None, + type_args=[relax.ObjectType()], + ) + ) + bb.emit_func_output(o) + + _check(foo, bb.get()["foo"]) + + +def test_empty_shape(): + @R.function + def foo(x: R.Tensor((), "float32")): + z = R.call_tir("scalar_add", x, (), dtype="float32") + return z + + (z_bind,) = foo.body.blocks[0].bindings + shape_expr = z_bind.value.args[2] + + assert isinstance(shape_expr, relax.ShapeExpr) + assert len(shape_expr.values) == 0 + + +def test_local_function(): + @R.function + def main( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 3), "float32") + ) -> R.Tensor((2, 3), "float32"): + @R.function + def outer_func( + c1: R.Tensor((2, 3), "float32") + ) -> R.Callable((R.Tensor(None, "float32", ndim=2),), R.Tensor(None, "float32", ndim=2)): + @R.function + def inner_func(x1: R.Tensor((2, 3), "float32")): + s: R.Tensor((2, 3), "float32") = R.add(x1, c1) + return s + + return inner_func + + in_call = outer_func(x) + res = in_call(y) + return res + + main_bindings = main.body.blocks[0].bindings + assert len(main_bindings) == 3 + outer_func = main_bindings[0].value + assert isinstance(outer_func, relax.Function) + + outer_func_bindings = outer_func.body.blocks[0].bindings + assert len(outer_func_bindings) == 1 + inner_func = outer_func_bindings[0].value + assert isinstance(inner_func, relax.Function) + + @I.ir_module + class TestModule: + @R.function + def f(x: R.Tensor((128, 128), "float32"), y: R.Tensor((128, 128), "float32")): + @T.prim_func + def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + C = T.match_buffer(c, (128, 128)) + + for i, j, k in T.grid(128, 128, 128): + with T.block(): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] += A[vi, vk] * B[vj, vk] + + z = relax.call_tir(my_matmul, (x, y), (128, 128), dtype="float32") + return z + + bindings = TestModule["f"].body.blocks[0].bindings + assert len(bindings) == 2 + tir_func = bindings[0].value + assert isinstance(tir_func, tir.PrimFunc) + + +def test_if_branch(): + @R.function + def foo(cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32")) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return y + + cond, x = foo.params + y_bind = foo.body.blocks[0].bindings[0] + y, ite = y_bind.var, y_bind.value + + assert isinstance(y, relax.Var) + assert y.name_hint == "y" + + assert isinstance(ite, relax.If) + assert isinstance(ite.true_branch, relax.SeqExpr) + assert isinstance(ite.false_branch, relax.SeqExpr) + + def check_call(call, op, args): + assert isinstance(call, relax.Call) + if isinstance(op, str): + assert str(call.op) == op + else: + assert call.op == op + tvm.ir.assert_structural_equal(call.args, args) + + w_bind = ite.true_branch.blocks[0].bindings[0] + body = ite.true_branch.body + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.add", [x, x]) + check_call(body, "relax.multiply", [w_bind.var, w_bind.var]) + + w_bind = ite.false_branch.blocks[0].bindings[0] + body = ite.false_branch.body + assert w_bind.var.name_hint == "w" + check_call(w_bind.value, "relax.multiply", [x, x]) + check_call(body, "relax.add", [w_bind.var, w_bind.var]) + + +def test_if_inside_dataflow(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor((1,), "float32"): + with R.dataflow(): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + R.output(y) + return y + + +def test_if_branch_output_name(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + z = R.add(w, w) + return y + + +def test_if_branch_var_scope(): + with pytest.raises(tvm.error.DiagnosticError): + + @R.function + def foo( + cond: R.Tensor((), "bool"), x: R.Tensor((1,), "float32") + ) -> R.Tensor((1,), "float32"): + if cond: + w = R.add(x, x) + y = R.multiply(w, w) + else: + w = R.multiply(x, x) + y = R.add(w, w) + return w + + +def test_other_cases(): + # They are corner case tests, which is only to check if it can be parsed. + # No need to add structural equal checks here + @R.function + def foo(x: R.Tensor): + return R.unique(x, sorted=True) + + @R.function + def bar(x: R.Tensor): + return R.print(x, format="{}") + + +if __name__ == "__main__": + tvm.testing.main()