diff --git a/docs/how_to/tutorials/cross_compilation_and_rpc.py b/docs/how_to/tutorials/cross_compilation_and_rpc.py index 81c73fd051ef..94a6f48b4b73 100644 --- a/docs/how_to/tutorials/cross_compilation_and_rpc.py +++ b/docs/how_to/tutorials/cross_compilation_and_rpc.py @@ -119,7 +119,7 @@ else: target = "llvm -mtriple=armv7l-linux-gnueabihf" -func = tvm.build(mod, target=target, name="add_one") +func = tvm.build(mod, target=target) # save the lib at a local temp folder temp = utils.tempdir() path = temp.relpath("lib.tar") diff --git a/docs/reference/api/python/driver.rst b/docs/reference/api/python/driver.rst index 1f1bc8c7cf7b..97c30ec2d25b 100644 --- a/docs/reference/api/python/driver.rst +++ b/docs/reference/api/python/driver.rst @@ -19,6 +19,4 @@ tvm.driver ---------- .. automodule:: tvm.driver -.. autofunction:: tvm.lower - .. autofunction:: tvm.build diff --git a/include/tvm/driver/driver_api.h b/include/tvm/driver/driver_api.h deleted file mode 100644 index 39444d1629fe..000000000000 --- a/include/tvm/driver/driver_api.h +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/driver/driver_api.h - * \brief Compiler driver APIs to drive the compilation. - * - * This module provides end-to-end utils to drive the compilation process. - * We adopt the term "compiler driver" in common compiler infrastructures. - * Note that a compiler driver is different from "runtime drivers". - * Most of runtime related code are defined in the runtime folder instead. - */ -#ifndef TVM_DRIVER_DRIVER_API_H_ -#define TVM_DRIVER_DRIVER_API_H_ - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { -using tvm::transform::Pass; - -/*! - * \brief Configures and returns the composite Pass for the fused module (pre split) that contains - * device and host code. - * \param mixed_mod The original mixed module. - * \param target The device Target. - * \return The composite Pass for the fused module. -// */ -TVM_DLL transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target); - -/*! - * \brief Configures and returns the composite Pass for the device Target after device/host from - * mixed module. - * \param mixed_mod The optimized mixed module. - * \param target The device Target. - * \return The composite Pass for the device module. - */ -TVM_DLL transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target); - -/*! - * \brief Configures and returns the composite Pass for the host Target after device/host from mixed - * module. - * \param mixed_mod The optimized mixed module. - * \param target_host The host Target. - * \return The composite Pass for the host module. 
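Earlier in this patch, the tutorial and docs hunks drop both `tvm.lower` and the `name=` argument of `tvm.build`: the exported runtime symbol now comes from the PrimFunc's `global_symbol` attribute rather than from a keyword argument. A minimal sketch of the migrated call pattern, assuming an LLVM-enabled build (the `add_one` kernel and its shape are illustrative, not taken from the tutorial):

import numpy as np
import tvm
from tvm.script import tir as T

@T.prim_func
def add_one(A: T.Buffer((8,), "float32"), B: T.Buffer((8,), "float32")):
    # The runtime entry point is named by this attribute, not by a name= argument.
    T.func_attr({"global_symbol": "add_one"})
    for i in range(8):
        B[i] = A[i] + T.float32(1)

lib = tvm.build(add_one, target="llvm")  # was: tvm.build(mod, target=target, name="add_one")
a = tvm.nd.array(np.arange(8, dtype="float32"))
b = tvm.nd.array(np.zeros(8, dtype="float32"))
lib["add_one"](a, b)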
- */ -TVM_DLL transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host); - -/*! - * \brief Lower an IRModule (optimize with it with the pass list defined in CreatePassList) - * \param mod The IRmodule to lower - * \param simple_mode Disables the loop partition pass. Defaults to false. - * \return The result module. - */ -TVM_DLL IRModule LowerModule(IRModule mod, bool simple_mode = false); - -/*! - * \brief Lower a primfunc and name (convert to IRModule, and optimize it with the pass list - * defined in CreatePassList) - * \param func The PrimFunc to lower - * \param name The name of the lowered function. - * \param simple_mode Disables the loop partition pass. Defaults to false. - * \return The result module. - */ -TVM_DLL IRModule LowerPrimFunc(tvm::tir::PrimFunc func, const std::string& name, - bool simple_mode = false); - -/*! - * \brief Build a device and host module for a specific target from an IRModule. - * \param funcs The functions to be built. - * \param target The target device to build for. - * \param target_host The target for building host code. To use the default, pass Target() - * \return The built module. - */ -TVM_DLL runtime::Module build(const IRModule& funcs, const Target& target, - const Target& target_host); - -/*! - * \brief Build a device and host module for a specific target from a map - * contains target to IRModule. This function is used - * for heterogeneous build. - * \param input The map contains target to an IRModule. - * \param target_host The target for building host code. To use the default, - * pass Target(). - * \return The built module that contains code for different processors. - */ -TVM_DLL runtime::Module build(const Map& input, const Target& target_host); - -/*! - * \brief Build a device and host module for a specific target from a map - * contains target to IRModule. This function is used - * for heterogeneous build. - * \param input The map contains target string to an IRModule. - * \param target_host The target for building host code. To use the default, - * pass Target(). - * \return The built module that contains code for different processors. - */ -TVM_DLL runtime::Module build(const Map& input, const Target& target_host); -} // namespace tvm - -#endif // TVM_DRIVER_DRIVER_API_H_ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index abbab3ad6d39..f4519f834d74 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -55,7 +55,7 @@ from . import te # tvm.driver -from .driver import build, lower +from .driver import build # others from . import arith diff --git a/python/tvm/driver/__init__.py b/python/tvm/driver/__init__.py index 75e94cc91c83..b97375c3a364 100644 --- a/python/tvm/driver/__init__.py +++ b/python/tvm/driver/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. """Namespace for driver APIs""" -from .build_module import lower, build +from .build_module import build diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 94006111ffa2..8d6a2a534389 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -18,130 +18,15 @@ # pylint: disable=invalid-name """The build utils in python.""" from typing import Union, Optional - - -import tvm.tir - - -from tvm.runtime import ndarray +import tvm from tvm.tir import PrimFunc from tvm.ir.module import IRModule from tvm.target import Target -from tvm.driver import _ffi_api as _driver_ffi - -from . 
import _ffi_api as ffi - - -def lower( - inp: Union[PrimFunc, IRModule], - name: str = "main", - simple_mode: bool = False, -) -> IRModule: - """Lowering step before build into target. - - Parameters - ---------- - inp : Union[tvm.tir.PrimFunc, IRModule] - The TE schedule or TensorIR PrimFunc/IRModule to be built - - name : str - The name of the result function. - - simple_mode : bool - Whether only output simple and compact statement, this will skip - LoopPartition, api wrapper generation and Unrolling. - - Returns - ------- - m : IRModule - The result IRModule - """ - if isinstance(inp, IRModule): - return ffi.lower_module(inp, simple_mode) - if isinstance(inp, PrimFunc): - return ffi.lower_primfunc(inp, name, simple_mode) - raise ValueError( - f"Expected input to be an IRModule, PrimFunc or te.Schedule, but got {type(inp)}" - ) def build( - inputs: Union[PrimFunc, IRModule], + mod: Union[PrimFunc, IRModule], target: Optional[Union[str, Target]] = None, - name: str = "main", + pipeline: Optional[Union[str, tvm.transform.Pass]] = "default_tir", ): - """Build a function with arguments as signature. Code will be generated - for devices coupled with target information. - - Parameters - ---------- - input : Union[tvm.tir.PrimFunc, IRModule] - The input to be built - - target : Optional[Union[str, Target]] - The target and option of the compilation. - - name : str - The name of result function. - - Returns - ------- - ret : tvm.module - A module that combines both host and device code. - - Note - ---- - See the note on :any:`tvm.target` on target string format. - """ - if isinstance(inputs, PrimFunc): - input_mod = lower(inputs, name=name) - elif isinstance(inputs, tvm.IRModule): - assert ( - len(inputs.get_global_vars()) > 0 - ), "Expected a non-empty IRModule, but the IRModule contained no functions." 
- input_mod = lower(inputs) - else: - raise ValueError("Inputs must be IRModule or PrimFunc") - - target = Target.current() if target is None else target - if target is None and isinstance(input_mod, tvm.IRModule): - target_mod = {} - for gvar, func in input_mod.functions.items(): - tgt = func.attrs["target"] if "target" in func.attrs else "llvm" - if tgt not in target_mod: - target_mod[tgt] = {} - target_mod[tgt][gvar] = func - - target_input_mod = {} - for tgt in target_mod.keys(): - tir_mod = tvm.IRModule(target_mod[tgt]) - tir_mod = tir_mod.with_attrs(input_mod.attrs) - target_input_mod[tgt] = tir_mod - else: - target_input_mod = {target: input_mod} - - # Because modules can be created from a variety of sources, we annotate them - # with the relevant attributes here to ensure they propagate - annotated_mods = {} - for tgt, mod in target_input_mod.items(): - if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or " "Target when inputs is dict.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be IRModule, " "or dict of str to IRModule.") - annotated_mods[tgt] = mod - - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods) - if not target_host: - for tar, mod in annotated_mods.items(): - device_type = ndarray.device(tar.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type: - target_host = tar - break - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - - rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host) - - return rt_mod_host + return tvm.tir.build(mod, target, pipeline) diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 1d7352f66527..9ff5bff5f1ff 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -108,3 +108,5 @@ from . import transform from . import analysis from . import stmt_functor +from .build import build +from .pipeline import get_pipeline diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py new file mode 100644 index 000000000000..cd44ed881ba3 --- /dev/null +++ b/python/tvm/tir/build.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name +"""The build utils in python.""" +from typing import Union, Optional, Dict +import enum + +import tvm +from tvm import ir +from tvm.runtime import ndarray +from tvm.tir import PrimFunc +from tvm.ir.module import IRModule +from tvm.target import Target + + +def split_host_device_mods(mod): + """Split an IRModule into host and device modules. 
+ + Parameters + ---------- + mod : tvm.IRModule + The input module to split + + Returns + ------- + host_mod : tvm.IRModule + The module containing host functions + device_mod_dict : Dict[Target, tvm.IRModule] + A dict mapping targets to device modules + """ + + class CallConv(enum.IntEnum): + """Enum representing different calling conventions. + Corresponds to the C++ tvm::ir::CallingConv enum. + """ + + kDefault = 0 + kCPackedFunc = 1 + kDeviceKernelLaunch = 2 + + host_mod = tvm.tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + != int(CallConv.kDeviceKernelLaunch) + )(mod) + device_mod = tvm.tir.transform.Filter( + lambda f: int(f.attrs.get("calling_conv", CallConv.kDefault)) + == int(CallConv.kDeviceKernelLaunch) + )(mod) + device_mod_dict = {} + for gv, func in device_mod.functions.items(): + device_mod_dict.setdefault(func.attrs.get("target", None), dict()).update({gv: func}) + for target, funcs in device_mod_dict.items(): + device_mod_dict[target] = tvm.IRModule(funcs, attrs=device_mod.attrs) + return host_mod, device_mod_dict + + +def codegen_build(mod: IRModule, target: Target) -> tvm.runtime.Module: + """Build a runtime module from an IRModule and a Target.""" + if tvm.ir.transform.PassContext.current().config.get("tir.disable_assert", False): + mod = tvm.tir.transform.SkipAssert()(mod) + build_f_name = "target.build." + target.kind.name + bf = tvm.get_global_func(build_f_name) + if bf is None: + raise ValueError(f"{build_f_name} is not enabled") + return bf(mod, target) + + +def tir_to_runtime( + host_mod: IRModule, device_mod_dict: Dict[Target, IRModule], target_host: Target +): + """Convert a collection of TIR IRModules (keyed by Target) into a single runtime Module.""" + + # Get the first module to get the attributes + # necessary for tests/python/codegen/test_target_codegen_blob.py::test_cuda_multi_lib + mhost_all = ir.IRModule({}, attrs=host_mod.attrs) + + mhost_all.update(host_mod) + device_modules = [] + for target, device_mod in device_mod_dict.items(): + if len(device_mod.functions) != 0: + device_modules.append(codegen_build(device_mod, target)) + + mhost = codegen_build(mhost_all, target_host) + for dev_mod in device_modules: + if dev_mod is not None: + mhost.import_module(dev_mod) + return mhost + + +def build( + mod: Union[PrimFunc, IRModule], + target: Optional[Union[str, Target]] = None, + pipeline: Union[None, str, tvm.transform.Pass] = "default_tir", +): + """Build a function with a signature, generating code for devices + coupled with target information. + + Parameters + ---------- + mod : Union[PrimFunc, IRModule] + The input to be built. + target : Optional[Union[str, Target]] + The target for compilation. + pipeline : Union[None, str, tvm.transform.Pass] + The pipeline to use for compilation. + + Returns + ------- + tvm.runtime.Module + A module combining both host and device code. 
+ """ + # Convert PrimFunc to IRModule + if isinstance(mod, PrimFunc): + mod = tvm.IRModule.from_expr(mod) + else: + assert isinstance(mod, tvm.IRModule) + + # Step 0: Determine the target in environment + target = Target.current() if target is None else target + if target is None: + target = "llvm" + assert target is not None + target = Target.canon_target(target) + + # Step 1: Determine the host + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + if target is not None: + if target.host is not None: + target_host = target.host + elif ndarray.device(target.kind.name, 0).device_type == ndarray.cpu(0).device_type: + target_host = target + else: + for func in mod.functions.values(): + f_target = func.attrs.get("target", None) + if f_target is not None and f_target.host is not None: + target_host = f_target.host + assert target_host is not None + target_host = Target.canon_target(target_host) + target = target.with_host(target_host) + + # Step 2: Bind the target to the input module + mod = tvm.tir.transform.BindTarget(target)(mod) + + # Step 3: Apply the pipeline + if pipeline is not None: + if isinstance(pipeline, str): + pipeline = tvm.tir.get_pipeline(pipeline) + mod = pipeline(mod) + + # Step 4: Get host and device modules + host_mod, device_mod_dict = split_host_device_mods(mod) + + # Step 5: Apply finalization passes + host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod) + device_mod_dict = { + target: tvm.tir.pipeline.finalize_device_passes()(device_mod) + for target, device_mod in device_mod_dict.items() + } + + # Convert TIR IRModules to runtime Module by calling target.build + return tir_to_runtime(host_mod, device_mod_dict, target_host) + + +tvm.register_func("tir.build", build) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py new file mode 100644 index 000000000000..0b6d622c90e1 --- /dev/null +++ b/python/tvm/tir/pipeline.py @@ -0,0 +1,175 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=invalid-name +"""The TIR backend compilation pipeline.""" + +import tvm +from tvm import tir + + +def default_tir_pipeline(): + """The default tir pipeline used in tvm.tir.build""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + """The default lowering passes for TIR backend.""" + pass_ctx = tvm.transform.PassContext.current() + config = pass_ctx.config + passes = [ + tir.transform.InjectPrefetch(), + tir.transform.TextureFlatten(), + tir.transform.StorageFlatten( + 64, bool(config.get("tir.instrument_bound_checkers", False)) + ), + tir.transform.LowerCrossThreadReduction(), + tir.transform.LowerInitBlock(), + tir.transform.PlanAndUpdateBufferAllocationLocation(), + tir.transform.ConvertBlocksToOpaque(), + tir.transform.LiftThreadBinding(), + tir.transform.ManifestSharedMemoryLocalStage(), + tir.transform.CompactBufferAllocation(), + tir.transform.LowerAutoCopy(), + tir.transform.UnifyThreadBinding(), + tir.transform.LowerMatchBuffer(), + tir.transform.Simplify(), + tir.transform.InjectPermutedLayout(), + tir.transform.InjectSoftwarePipeline(), + tir.transform.TransformMmaBufferLayout(), + tir.transform.LowerOpaqueBlock(), + tir.transform.FlattenBuffer(), + tir.transform.BF16ComputeLegalize(), + tir.transform.NarrowDataType(32), + tir.transform.LoopPartition(), + tir.transform.VectorizeLoop(not bool(config.get("tir.disable_vectorize", False))), + tir.transform.InjectVirtualThread(), + tir.transform.InjectDoubleBuffer(), + ] + if not bool(config.get("tir.disable_storage_rewrite", False)): + passes.append(tir.transform.StorageRewrite()) + if config.get("tir.use_async_copy", False): + passes.append(tir.transform.LowerAsyncDMA()) + passes.extend( + [ + tir.transform.HoistIfThenElse(), + tir.transform.UnrollLoop(), + tir.transform.RenormalizeSplitPattern(), + tir.transform.Simplify(), + tir.transform.RemoveNoOp(), + tir.transform.RewriteUnsafeSelect(), + ] + ) + # Additional passes based on configuration. + if bool(config.get("tir.instrument_bound_checkers", False)): + passes.append(tir.transform.InstrumentBoundCheckers()) + if bool(config.get("tir.ptx_ldg32", False)): + passes.append(tir.transform.InjectPTXLDG32(True)) + passes.append( + tir.transform.CommonSubexprElimTIR( + not bool(config.get("tir.disable_cse_tir", False)), + bool(config.get("tir.enable_equiv_terms_in_cse_tir", False)), + ) + ) + if bool(config.get("tir.instrument_lwp", False)): + passes.append(tir.transform.InstrumentProfileIntrinsics()) + passes.extend( + [ + # Bind the target first so that target-specific attributes are available. + tir.transform.FP8ComputeLegalize(), + # VerifyVTCMLimit must occur before LowerVtcmAlloc. 
+ tir.transform.VerifyVTCMLimit(), + tir.transform.LowerVtcmAlloc(), + tir.transform.VerifyMemory(), + tir.transform.AnnotateEntryFunc(), + ] + ) + if bool(config.get("tir.detect_global_barrier", False)): + passes.append(tir.transform.ThreadSync("global")) + passes.extend( + [ + tir.transform.ThreadSync("shared"), + tir.transform.ThreadSync("shared.dyn"), + tir.transform.ThreadSync("warp"), + tir.transform.InferFragment(), + tir.transform.LowerThreadAllreduce(), + ] + ) + if bool(config.get("tir.use_async_copy", False)): + passes.append(tir.transform.InjectPTXAsyncCopy()) + if bool(config.get("tir.ptx_ldg32", False)): + passes.append(tir.transform.InjectPTXLDG32()) + passes.extend( + [ + tir.transform.AnnotateDeviceRegions(), + tir.transform.SplitHostDevice(), + # MergeSharedMemoryAllocations must follow SplitHostDevice. + tir.transform.MergeSharedMemoryAllocations(), + tir.transform.MakePackedAPI(), + tir.transform.FP8StorageLegalize(), + tir.transform.BF16StorageLegalize(), + tir.transform.LowerDeviceKernelLaunch(), + ] + ) + mod = tvm.ir.transform.Sequential(passes)(mod) + return mod + + return _pipeline + + +def finalize_host_passes(): # pylint: disable=unused-argument + """The default finalization passes for TIR backend.""" + host_pass_list = [ + tir.transform.LowerTVMBuiltin(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerIntrin(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.CombineContextCall(), + ] + return tvm.ir.transform.Sequential(host_pass_list) + + +def finalize_device_passes(): # pylint: disable=unused-argument + """The default finalization passes for TIR backend.""" + device_pass_list = [ + tir.transform.LowerWarpMemory(), + tir.transform.Simplify(), + tir.transform.LowerCustomDatatypes(), + tir.transform.LowerDeviceStorageAccessInfo(), + tir.transform.LowerIntrin(), + ] + return tvm.ir.transform.Sequential(device_pass_list) + + +# global map of pre-built pipelines +PIPELINE_MAP = { + "default_tir": default_tir_pipeline, +} + + +def get_pipeline(name: str = "default_tir", **kwargs) -> tvm.transform.Pass: + """Get pre-build pipeline by name + + Parameters + ---------- + name : Optional[str] + Name of the pipeline + """ + if name not in PIPELINE_MAP: + raise ValueError( + f"Unknown pre-built pipeline {name}," f"candidates are {list(PIPELINE_MAP.keys())}" + ) + return PIPELINE_MAP[name](**kwargs) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index b08659e1c712..99a2e1e66485 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -713,7 +713,7 @@ def VerifyMemory(): return _ffi_api.VerifyMemory() # type: ignore -def VerifyVTCMLimit(limit: int): +def VerifyVTCMLimit(limit=None): """Verify if the size of the allocated vtcm memory satisfies the limit. Returns @@ -1200,3 +1200,36 @@ def UseAssumeToReduceBranches(): The result pass """ return _ffi_api.UseAssumeToReduceBranches() # type: ignore + + +def LowerAsyncDMA(): + """Lower async DMA to DMA. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerAsyncDMA() # type: ignore + + +def InjectPTXLDG32(enable_inject_ptx_intrin: bool = True): + """Inject ptx.ldg.32 intrinsics. + + Parameters + ---------- + enable_inject_ptx_intrin : bool + If True, inject ptx.ldg.32 intrinsics. + """ + return _ffi_api.InjectPTXLDG32(enable_inject_ptx_intrin) # type: ignore + + +def LowerVtcmAlloc(): + """Lower vtcm allocation. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.LowerVtcmAlloc() # type: ignore diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc deleted file mode 100644 index 5b12f13d96a6..000000000000 --- a/src/driver/driver_api.cc +++ /dev/null @@ -1,595 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Compile executable modules. - * \file driver_api.cc - */ -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace tvm { - -// Register build pipeline related options -TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); - -// WARNING: May cause coherency issues resulting data miscompares -// Experimental feature that, when enabled by the runtime, bypasses the cache when using DMA. When -// bypassing the cache TVM must manage cache coherency in software. Software managed cache coherency -// can be tricky e.g. it is yet to be proven out in the Hexagon runtime. Hence the warning above and -// the "experimental" notation for this feature. -TVM_REGISTER_PASS_CONFIG_OPTION("tir.experimental_dma_bypass_cache", Bool); - -using tvm::Array; -using tvm::transform::Pass; - -bool LLVMEnabled() { - const runtime::PackedFunc* pf = runtime::Registry::Get("target.build.llvm"); - return pf != nullptr; -} - -/*! 
\return The default host target for a given device target */ -Target DefaultTargetHost(Target target) { - if (target.defined() && target->GetTargetDeviceType() == kDLCPU) { - return target; - } else { - if (LLVMEnabled()) { - return Target("llvm"); - } else { - return Target("stackvm"); - } - } -} - -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list) { - *out_binds = binds; - - for (const ObjectRef& x : args) { - if (auto tensor_node = x.as()) { - te::Tensor x_ref = tensor_node.value(); - if (out_binds->find(x_ref) == out_binds->end()) { - tir::Buffer buf = tir::BufferWithOffsetAlignment(x_ref->shape, x_ref->dtype, - x_ref->op->name, -1, 0, compact); - out_binds->Set(x_ref, buf); - out_arg_list->push_back(buf); - } else { - out_arg_list->push_back((*out_binds)[x_ref]); - } - } else if (x.as() || x.as()) { - out_arg_list->push_back(x); - } else { - LOG(FATAL) - << "Expected type of the elements of args to be te::Tensor, te::Buffer or tir::Var, " - << "but got a " << x->GetTypeKey(); - } - } -} - -void GetBinds(const Array& args, bool compact, - const std::unordered_map& binds, - Map* out_binds, Array* out_arg_list) { - Array ref_args; - for (ObjectRef x : args) { - ref_args.push_back(x); - } - GetBinds(ref_args, compact, binds, out_binds, out_arg_list); -} - -TVM_REGISTER_GLOBAL("driver.get_binds") - .set_body_typed([](const Array& args, bool compact, - const Map& binds) { - std::unordered_map c_binds; - // Check to make sure binds is not null before doing the conversion; - if (binds.get() != nullptr) { - for (auto kv : binds) { - c_binds.insert({kv.first, kv.second}); - } - } - Map out_binds; - Array out_arg_list; - GetBinds(args, compact, c_binds, &out_binds, &out_arg_list); - - // TVM object system doesn't have a pair object, so we'll put both ret values in an array - // and return that. - Array out_arr = {out_binds, out_arg_list}; - return out_arr; - }); - -Array CreatePassList(bool disable_loop_partition) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - - bool disable_vectorize = pass_ctx->GetConfig("tir.disable_vectorize", Bool(false)).value(); - bool disable_storage_rewrite = - pass_ctx->GetConfig("tir.disable_storage_rewrite", Bool(false)).value(); - bool instrument_bound_checkers = - pass_ctx->GetConfig("tir.instrument_bound_checkers", Bool(false)).value(); - bool disable_cse_tir = pass_ctx->GetConfig("tir.disable_cse_tir", Bool(false)).value(); - bool enable_equiv_terms_in_cse_tir = - pass_ctx->GetConfig("tir.enable_equiv_terms_in_cse_tir", Bool(false)).value(); - - bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - - // Get any user-added passes - Array> add_lower_pass = - pass_ctx->GetConfig>>("tir.add_lower_pass", Array>()) - .value(); - - bool instrument_lwp = pass_ctx->GetConfig("tir.instrument_lwp", Bool(false)).value(); - - Array user_lower_phase0 = Array(); - Array user_lower_phase1 = Array(); - Array user_lower_phase2 = Array(); - Array user_lower_phase3 = Array(); - - // phase passes is of the form - // [[phase_number, pass], [phase_number, pass]... 
] - for (Array phase_pass : add_lower_pass) { - auto phase_num = phase_pass[0].as(); - ICHECK(phase_num) - << "Expected the first entry in the inner Array of tir.add_lower_pass to be an integer, " - << "but instead received " << phase_pass[0] << " with type " << phase_pass[0]->GetTypeKey(); - int phase_num_val = phase_num->value; - - CHECK_GE(phase_num_val, 0); - - auto pass = Downcast(phase_pass[1]); - // Copy the pass into the correct phase - if (phase_num_val == 0) { - user_lower_phase0.push_back(pass); - } else if (phase_num_val == 1) { - user_lower_phase1.push_back(pass); - } else if (phase_num_val == 2) { - user_lower_phase2.push_back(pass); - } else if (phase_num_val >= 3) { - user_lower_phase3.push_back(pass); - } - } - - // Construct the pass list, inserting the user provided passes at the end of the phase - - // PHASE 0 - Array pass_list = user_lower_phase0; - - // PHASE 1 - pass_list.push_back(tir::transform::InjectPrefetch()); - pass_list.push_back(tir::transform::TextureFlatten()); - pass_list.push_back(tir::transform::StorageFlatten(64, instrument_bound_checkers)); - pass_list.push_back(tir::transform::LowerCrossThreadReduction()); - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::LiftThreadBinding()); - pass_list.push_back(tir::transform::ManifestSharedMemoryLocalStage()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerAutoCopy()); - pass_list.push_back(tir::transform::UnifyThreadBinding()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::InjectPermutedLayout()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::TransformMmaBufferLayout()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::BF16ComputeLegalize()); - pass_list.push_back(tir::transform::NarrowDataType(32)); - pass_list.push_back(tir::transform::Simplify()); - - // Add user-defined phase-1 passes - pass_list.insert(pass_list.end(), user_lower_phase1.begin(), user_lower_phase1.end()); - - // PHASE 2 - if (!disable_loop_partition) { - pass_list.push_back(tir::transform::LoopPartition()); - } - - pass_list.push_back(tir::transform::VectorizeLoop(!disable_vectorize)); - pass_list.push_back(tir::transform::InjectVirtualThread()); - pass_list.push_back(tir::transform::InjectDoubleBuffer()); - if (!disable_storage_rewrite) { - pass_list.push_back(tir::transform::StorageRewrite()); - } - bool use_async_copy = pass_ctx->GetConfig("tir.use_async_copy", Bool(false)).value(); - - if (use_async_copy) { - pass_list.push_back(tir::transform::LowerAsyncDMA()); - } - // HoistIfThenElse must be applied before UnrollLoop - // because HoistIfThenElse could utilize for loop structure - // which might be unrolled in UnrollLoop - pass_list.push_back(tir::transform::HoistIfThenElse()); - pass_list.push_back(tir::transform::UnrollLoop()); - - // Add user-defined phase-2 passes - pass_list.insert(pass_list.end(), user_lower_phase2.begin(), user_lower_phase2.end()); - - // PHASE 3 - pass_list.push_back(tir::transform::RenormalizeSplitPattern()); - 
pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::RemoveNoOp()); - pass_list.push_back(tir::transform::RewriteUnsafeSelect()); - - // Add user-defined phase-3 passes - pass_list.insert(pass_list.end(), user_lower_phase3.begin(), user_lower_phase3.end()); - - if (instrument_bound_checkers) { - pass_list.push_back(tir::transform::InstrumentBoundCheckers()); - } - - if (ptx_ldg32) { - pass_list.push_back(tir::transform::InjectPTXLDG32(true)); - } - - pass_list.push_back( - tir::transform::CommonSubexprElimTIR(!disable_cse_tir, enable_equiv_terms_in_cse_tir)); - - // This pass instruments the loops with the profile builtin calls to capture the runtime - // performance data (only enabled for Hexagon at the moment). To ensure that no other - // optimizations are performed on the instrumented code, this pass must be added at the end - // of the list. - if (instrument_lwp) { - pass_list.push_back(tir::transform::InstrumentProfileIntrinsics()); - } - - return pass_list; -} - -IRModule LowerWithPassList(IRModule mod, Array pass_list) { - auto optimize = tvm::transform::Sequential(pass_list); - mod = optimize(std::move(mod)); - return mod; -} - -IRModule ApplyPasses(IRModule mod, transform::Sequential seq) { - mod = seq(std::move(mod)); - return mod; -} - -IRModule LowerModule(IRModule mod, bool simple_mode) { - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); -} - -TVM_REGISTER_GLOBAL("driver.lower_module").set_body_typed([](IRModule mod, bool simple_mode) { - return LowerModule(std::move(mod), simple_mode); -}); - -IRModule LowerPrimFunc(tir::PrimFunc func, const std::string& name, bool simple_mode) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = WithAttr(std::move(func), "global_symbol", runtime::String(name)); - - bool noalias = pass_ctx->GetConfig("tir.noalias", Bool(true)).value(); - - if (noalias) { - f = WithAttr(std::move(f), "tir.noalias", Bool(true)); - } - IRModule mod = IRModule(Map({{GlobalVar(name), f}})); - - // Get the pass list - Array pass_list = CreatePassList(simple_mode); - return LowerWithPassList(std::move(mod), pass_list); -} - -TVM_REGISTER_GLOBAL("driver.lower_primfunc") - .set_body_typed([](te::PrimFunc func, const String& name, bool simple_mode) { - return LowerPrimFunc(std::move(func), name, simple_mode); - }); - -/** - * This function takes the input module that contains both the device and host opts. - * Then, it applies transformation on the original module before splitting into separate modules for - * device and host. Then it also applies transformations on the new splitted modules. 
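The `SplitMixedModule` / pass-manager plumbing deleted here has its Python counterpart in the new `python/tvm/tir/build.py`: the lowered mixed module is split by calling convention and each side then receives its own finalization passes. A rough sketch of that flow, assuming a CUDA-enabled build and that the helper stays importable from `tvm.tir.build`; the `scale` kernel is illustrative:

import tvm
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir.build import split_host_device_mods  # helper added in this patch

@T.prim_func
def scale(A: T.Buffer((1024,), "float32"), B: T.Buffer((1024,), "float32")):
    T.func_attr({"global_symbol": "scale"})
    for bx in T.thread_binding(8, thread="blockIdx.x"):
        for tx in T.thread_binding(128, thread="threadIdx.x"):
            B[bx * 128 + tx] = A[bx * 128 + tx] * T.float32(2)

target = Target("cuda", host="llvm")
mod = tvm.tir.transform.BindTarget(target)(tvm.IRModule.from_expr(scale))
mod = tvm.tir.get_pipeline("default_tir")(mod)        # lower the mixed module
host_mod, device_mods = split_host_device_mods(mod)   # split by calling_conv
host_mod = tvm.tir.pipeline.finalize_host_passes()(host_mod)
device_mods = {t: tvm.tir.pipeline.finalize_device_passes()(m) for t, m in device_mods.items()}
# tir_to_runtime() then runs target.build.<kind> on each piece and imports the
# device modules into the host module.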
- */ -std::pair SplitMixedModule(IRModule mod_mixed, const Target& target_arg, - const Target& target_host_arg) { - Target target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - - ICHECK(mod_mixed.defined()) << "This module must be defined"; - - mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target)); - - IRModule host_mod = ApplyPasses(mod_mixed, HostModulePassManager(mod_mixed, target_host)); - - IRModule device_mod = ApplyPasses(mod_mixed, DeviceModulePassManager(mod_mixed, target)); - - auto keys = target->GetKeys(); - - CheckAndUpdateHostConsistency(&target, &target_host); - - bool target_is_gpu = std::find(keys.begin(), keys.end(), "gpu") != keys.end(); - if (target_is_gpu && device_mod->functions.size() == 0) { - DLOG(WARNING) << "Specified target " << target->str() - << " but cannot find device code. Did you forget to bind?"; - } - - return {host_mod, device_mod}; -} - -/*! - * \brief Check and update host field of the given legacy heterogeneous targets and - * target host.Note that this function is for legacy target api compatibility issue only, - * not recommended for other use. - * \param ir_modules The pointer to a Map objects with keys being Target objects - * \param host The Target typed object for target host to be updated - */ -void CheckAndUpdateHostConsistency(Map* targets, Target* host) { - Map new_targets; - for (auto& it : *targets) { - auto target = it.first; - CheckAndUpdateHostConsistency(&target, host); - new_targets.Set(target, it.second); - } - *targets = new_targets; -} - -runtime::Module TIRToRuntime(const Map& inputs_arg, - const Target& target_host_arg) { - CHECK(inputs_arg.size()) << "TIRToRuntime expects at least one IRModule as input."; - std::vector device_modules; - Map inputs = inputs_arg; - Target target_host = target_host_arg; - - // Fetch previous defined target host in targets - CheckAndUpdateHostConsistency(&inputs, &target_host); - - if (!target_host.defined()) { - for (const auto& it : inputs) { - if (it.first->GetTargetDeviceType() == kDLCPU) { - target_host = it.first; - break; - } - } - } - - if (!target_host.defined()) { - target_host = DefaultTargetHost(target_host); - } - - // Update target host for all targets - CheckAndUpdateHostConsistency(&inputs, &target_host); - - // Take the attrs from the first module so the eventual modules have them. - // Ideally this would just be one unified module all the way through; - IRModule first_module = (*inputs.begin()).second; - IRModule mhost_all = IRModule(Map(), {}, first_module->attrs); - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - for (const auto& it : inputs) { - if (it.second.defined()) { - const Target& target = it.first; - const IRModule& ir_module = it.second; - auto pair = SplitMixedModule(ir_module, target, target_host); - auto& host_mod = pair.first; - auto& device_mod = pair.second; - - ICHECK(host_mod.defined()) << "The split host module must be defined"; - - ICHECK(mhost_all.defined()) << "The host module must be defined"; - - // We don't want library modules going back into host codegen - // unless they're supposed to. Here if we overrode the target host - // to allow lowering previously we check that it's meant to be placed - // back into the host Module. 
- bool overrides_host_target = - target->GetTargetDeviceType() == target_host->GetTargetDeviceType(); - bool non_host_target_kind = target->kind != target_host->kind; - if (overrides_host_target && non_host_target_kind) { - device_modules.push_back(codegen::Build(host_mod, it.first)); - } else { - mhost_all->Update(host_mod); - } - - if (device_mod->functions.size() != 0) { - device_modules.push_back(codegen::Build(device_mod, it.first)); - } - } - } - - runtime::Module mhost = codegen::Build(mhost_all, target_host); - for (const auto& it : device_modules) { - if (it.operator->()) { - mhost.Import(it); - } - } - - return mhost; -} - -TVM_REGISTER_GLOBAL("driver.tir_to_runtime") - .set_body_typed([](const Map& inputs_arg, Target host_target) { - return TIRToRuntime(inputs_arg, host_target); - }); - -// Build for heterogeneous execution when targets are specified as -// objects. This wrapper around the internal API is maintained for -// backwards compatibility. -runtime::Module build(const Map& input, const Target& target_host) { - return TIRToRuntime(input, target_host); -} - -// Build for heterogeneous execution when target is a string. -runtime::Module build(const Map& inputs_arg, const Target& target_host_arg) { - Map updated_inputs; - Target target_host = target_host_arg; - for (const auto& it : inputs_arg) { - Target target = Target(it.first); - CheckAndUpdateHostConsistency(&target, &target_host); - Optional device = target->GetAttr("device"); - if (device.defined() && device.value() == "vta") { - target = Target("ext_dev"); - } - updated_inputs.Set(target, it.second); - } - return TIRToRuntime(updated_inputs, target_host); -} - -// Build for homogeneous execution. -runtime::Module build(const IRModule& funcs, const Target& target_arg, - const Target& target_host_arg) { - auto target = target_arg, target_host = target_host_arg; - CheckAndUpdateHostConsistency(&target, &target_host); - // More maps of target and target host - Map inputs = {{target, funcs}}; - return TIRToRuntime(inputs, target_host); -} - -transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - - Array mixed_pass_list; - - // FPComputeLegalize uses the target attrs added by BindTarget, so it must come first - mixed_pass_list.push_back(tir::transform::BindTarget(target)); - mixed_pass_list.push_back(tir::transform::FP8ComputeLegalize()); - - // VerifyVTCMLimit must occur before LowerVtcmAlloc - mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(target)); - // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations - mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); - - mixed_pass_list.push_back(tir::transform::VerifyMemory()); - - mixed_pass_list.push_back(tir::transform::AnnotateEntryFunc()); - - bool detect_global_barrier = - pass_ctx->GetConfig("tir.detect_global_barrier", Bool(false)).value(); - if (detect_global_barrier) { - mixed_pass_list.push_back(tir::transform::ThreadSync("global")); - } - - mixed_pass_list.push_back(tir::transform::ThreadSync("shared")); - mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn")); - mixed_pass_list.push_back(tir::transform::ThreadSync("warp")); - mixed_pass_list.push_back(tir::transform::InferFragment()); - mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce()); - - bool use_async_copy = pass_ctx->GetConfig("tir.use_async_copy", Bool(false)).value(); - - if (use_async_copy) { - 
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy()); - } - - bool ptx_ldg32 = pass_ctx->GetConfig("tir.ptx_ldg32", Bool(false)).value(); - if (ptx_ldg32) { - mixed_pass_list.push_back(tir::transform::InjectPTXLDG32()); - } - - mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions()); - mixed_pass_list.push_back(tir::transform::SplitHostDevice()); - // MergeSharedMemoryAllocations must be applied after SplitHostDevice - // because the merged allocation site is at the beginning of each device function - mixed_pass_list.push_back(tir::transform::MergeSharedMemoryAllocations()); - - mixed_pass_list.push_back(tir::transform::MakePackedAPI()); - mixed_pass_list.push_back(tir::transform::FP8StorageLegalize()); - mixed_pass_list.push_back(tir::transform::BF16StorageLegalize()); - - mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch()); - - return transform::Sequential(mixed_pass_list); -} - -TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target) { - return MixedModulePassManager(mixed_mod, target); - }); - -transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { - transform::PassContext pass_ctx = transform::PassContext::Current(); - - Array host_pass_list; - - runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) != - CallingConv::kDeviceKernelLaunch; - }; - host_pass_list.push_back(tir::transform::Filter(fcond)); - - ICHECK(mixed_mod.defined()) << "This module must be defined"; - - host_pass_list.push_back(tir::transform::BindTarget(target_host)); - - host_pass_list.push_back(tir::transform::LowerTVMBuiltin()); - host_pass_list.push_back(tir::transform::LowerCustomDatatypes()); - host_pass_list.push_back(tir::transform::LowerIntrin()); - host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); - host_pass_list.push_back(tir::transform::CombineContextCall()); - - return transform::Sequential(host_pass_list); -} - -TVM_REGISTER_GLOBAL("driver.host_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target_host) { - return HostModulePassManager(mixed_mod, target_host); - }); - -transform::Sequential DeviceModulePassManager(IRModule mixed_mod, Target target) { - Array device_pass_list; - runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { - return f->GetAttr(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) == - CallingConv::kDeviceKernelLaunch; - }; - device_pass_list.push_back(tir::transform::Filter(fcond)); - - device_pass_list.push_back(tir::transform::BindTarget(target)); - - device_pass_list.push_back(tir::transform::LowerWarpMemory()); - device_pass_list.push_back(tir::transform::Simplify()); - device_pass_list.push_back(tir::transform::LowerCustomDatatypes()); - device_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); - device_pass_list.push_back(tir::transform::LowerIntrin()); - - return transform::Sequential(device_pass_list); -} - -TVM_REGISTER_GLOBAL("driver.device_mod_passes") - .set_body_typed([](IRModule mixed_mod, Target target_host) { - return DeviceModulePassManager(mixed_mod, target_host); - }); - -} // namespace tvm diff --git a/src/driver/internal_driver_api.h b/src/driver/internal_driver_api.h deleted file mode 100644 index 3b7cc7c7f7fa..000000000000 --- a/src/driver/internal_driver_api.h +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license 
agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/driver/driver_api.h - * \brief Internal compiler driver APIs to drive the compilation. - * - * This module provides functionality that may be called internally - * within TVM, but is not part of the public-facing API. - */ -#ifndef TVM_DRIVER_INTERNAL_DRIVER_API_H_ -#define TVM_DRIVER_INTERNAL_DRIVER_API_H_ - -#include -#include - -namespace tvm { - -/*! - * \brief Build a device and host module for a specific target from a map - * contains target to IRModule. This function is used - * for heterogeneous build. - * \param input The map contains target to an IRModule. - * \param target_host The target for building host code. To use the default, - * pass Target(). - * \return The built module that contains code for different processors. - */ -runtime::Module TIRToRuntime(const Map& input, const Target& target_host); - -} // namespace tvm - -#endif // TVM_DRIVER_INTERNAL_DRIVER_API_H_ diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 8c0ddeb6c34d..18da88be805d 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -21,7 +21,6 @@ * \file src/relax/backend/vm/codegen_vm.cc * \brief A codegen to generate VM executable from a Relax IRModule. */ -#include #include #include #include diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index a92cf7c749a0..e3812ea8c101 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -21,7 +21,6 @@ * \file src/relax/backend/vm/codegen_tir.cc * \brief A codegen to generate VMTIR function(that can be compiled) from executable. */ -#include #include #include #include diff --git a/src/relax/transform/bind_params.cc b/src/relax/transform/bind_params.cc index 27931b601760..14f68da3e4c1 100644 --- a/src/relax/transform/bind_params.cc +++ b/src/relax/transform/bind_params.cc @@ -17,7 +17,6 @@ * under the License. */ -#include #include #include #include diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc index ff193acf143e..fb6a01a19d7f 100644 --- a/src/relax/transform/fold_constant.cc +++ b/src/relax/transform/fold_constant.cc @@ -17,7 +17,6 @@ * under the License. 
*/ -#include #include #include #include @@ -116,8 +115,10 @@ class ConstantFolder : public ExprMutator { // already scheduled to only work on GPU, we will need to skip this in the const folder for // now // TODO(Hongyi): further check and narrow the scope of foldable function - runtime::Module rt_module = - build(LowerPrimFunc(func, "tir_function"), eval_cpu_target, eval_cpu_target); + auto* pf = runtime::Registry::Get("tir.build"); + ICHECK(pf != nullptr) << "Cannot find tir.build in registry"; + func = WithAttr(func, tvm::attr::kGlobalSymbol, String("tir_function")); + runtime::Module rt_module = (*pf)(func, eval_cpu_target); build_func = rt_module.GetFunction("tir_function"); } catch (const tvm::Error& err) { // build failure may happen in which case we skip diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index fbc43a00cad7..1c77219d453e 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -29,6 +29,26 @@ namespace tvm { namespace tir { namespace transform { +// Register build pipeline related options +TVM_REGISTER_PASS_CONFIG_OPTION("tir.noalias", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.detect_global_barrier", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_buffer_level_predication", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array>); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_static_smem", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool); + /*! * \brief Function level pass that applies transformations to all * TIR functions within the module. diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index 7f45fee9a26c..d5946fda216f 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -22,7 +22,6 @@ * \brief Passes that serve as helper functions. 
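In the `fold_constant.cc` hunk above, the C++ constant folder now reaches the builder through the global registry instead of the removed `LowerPrimFunc`/`build` C++ APIs. A rough Python rendering of that call path, assuming an LLVM-enabled build; the `identity` function is illustrative:

import tvm
from tvm.script import tir as T

@T.prim_func
def identity(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float32")):
    for i in range(4):
        B[i] = A[i]

# build.py registers the Python builder as a global PackedFunc named "tir.build",
# which is what fold_constant.cc now looks up via runtime::Registry::Get.
f_build = tvm.get_global_func("tir.build")
func = identity.with_attr("global_symbol", "tir_function")
rt_module = f_build(func, tvm.target.Target("llvm"))
packed = rt_module.get_function("tir_function")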
 */
-#include 
 #include 
 namespace tvm {
diff --git a/tests/python/codegen/test_target_codegen_cuda.py b/tests/python/codegen/test_target_codegen_cuda.py
index ae3173a14dee..b3cad9acd38e 100644
--- a/tests/python/codegen/test_target_codegen_cuda.py
+++ b/tests/python/codegen/test_target_codegen_cuda.py
@@ -766,6 +766,7 @@ def func3(A: T.Buffer((4, 4), "float32")) -> None:
         tvm.build(mod, target="cuda")


+@tvm.testing.requires_cuda
 def test_invalid_reinterpret():
     @T.prim_func
     def func(A: T.Buffer((4,), "uint32"), B: T.Buffer((4,), "uint8")) -> None:
diff --git a/tests/python/codegen/test_target_codegen_llvm.py b/tests/python/codegen/test_target_codegen_llvm.py
index e3ccff49ba1b..304c79559cbb 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -42,7 +42,7 @@ def test_llvm_intrin():
     body = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "prefetch"))
-    fcode = tvm.build(mod, None, "llvm")
+    fcode = tvm.build(mod, None)


 @tvm.testing.requires_llvm
@@ -54,7 +54,7 @@ def test_llvm_void_intrin():
     ib.emit(x)
     body = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
-    fcode = tvm.build(mod, None, "llvm")
+    fcode = tvm.build(mod, None)


 @tvm.testing.requires_llvm
@@ -106,7 +106,7 @@ def test_llvm_lookup_intrin():
     ib.emit(x)
     body = ib.get()
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
-    fcode = tvm.build(mod, None, "llvm")
+    fcode = tvm.build(mod, None)


 @tvm.testing.requires_llvm
@@ -557,9 +557,6 @@ def _show_info():
         print("dtype: {}".format(dtype))
         print("dividend range: [{}, {}]".format(start, end))
         print("divisor range: [{}, {}]".format(dstart, dend))
-        lowered = tvm.lower(sch.mod, simple_mode=True)
-        print("Lowered code:")
-        print(lowered)

     # Check that the computed values are correct
     for i in range(start, end + 1):
@@ -764,44 +761,6 @@ def check_llvm_ir():
     check_llvm_ir()


-@tvm.testing.requires_llvm
-def test_llvm_shuffle():
-    a = te.placeholder((8,), "int32")
-    b = te.placeholder((8,), "int32")
-    c = te.compute((8,), lambda x: a[x] + b[7 - x])
-
-    # Convert to TIR and create schedule
-    mod = te.create_prim_func([a, b, c])
-    sch = tir.Schedule(mod)
-
-    def my_vectorize():
-        def vectorizer(op):
-            store = op.body
-            idx = tvm.tir.Ramp(tvm.tir.const(0, "int32"), tvm.tir.const(1, "int32"), 8)
-            value = store.value
-            b_idx = tvm.tir.Shuffle([idx], [tvm.tir.const(i, "int32") for i in range(7, -1, -1)])
-            new_a = tvm.tir.BufferLoad(value.a.buffer, [idx])
-            new_b = tvm.tir.BufferLoad(value.b.buffer, [b_idx])
-            value = new_a + new_b
-            return tvm.tir.BufferStore(store.buffer, new_a + new_b, [idx])
-
-        def _transform(f, *_):
-            return f.with_body(
-                tvm.tir.stmt_functor.ir_transform(f.body, None, vectorizer, ["tir.For"])
-            )
-
-        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
-
-    with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, my_vectorize())]}):
-        ir = tvm.lower(sch.mod, simple_mode=True)
-        module = tvm.build(sch.mod)
-        a_ = tvm.nd.array(np.arange(1, 9, dtype="int32"))
-        b_ = tvm.nd.array(np.arange(8, 0, -1, dtype="int32"))
-        c_ = tvm.nd.array(np.zeros((8,), dtype="int32"))
-        module(a_, b_, c_)
-        tvm.testing.assert_allclose(c_.numpy(), (a_.numpy() * 2).astype("int32"))
-
-
 def np_float2np_bf16(arr):
     """Convert a numpy array of float to a numpy array of bf16 in uint16"""
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
index 99fc6ac074c2..169d868b5479 100644
--- a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
+++ b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
@@ -244,15 +244,6 @@ def apply_transform(block, buffer_name, layout):

         return [sch.mod]

-    @tvm.testing.fixture
-    def ir_module(self, schedule_args):
-        # If the two buffers are accessed with the same indices, CSE
-        # will replace them with a Let binding. Since this makes it
-        # harder to test what the transformed indices are, disabling
-        # the CSE pass for this test.
-        with tvm.transform.PassContext(disabled_pass=["tir.CommonSubexprElimTIR"]):
-            return tvm.lower(*schedule_args)
-
     @tvm.testing.fixture
     def uses_unsupported_physical_dimensions(  # pylint: disable=invalid-name
         self, target_host, input_layout, working_layout, output_layout
@@ -291,9 +282,6 @@ def test_cache_shape(self, ir_module, input_layout, working_layout, output_layou

         assert len(buffer.shape) == expected_physical_dimensions

-    def test_lower(self, schedule_args):
-        assert tvm.lower(*schedule_args)
-
     @requires_hexagon_toolchain
     def test_build(self, schedule_args, target_host, input_layout, working_layout, output_layout):
         """Testing build success/failure
diff --git a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
index a927532c8f4a..f0cefa3fe256 100644
--- a/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
+++ b/tests/python/contrib/test_hexagon/test_benchmark_elemwise_add.py
@@ -199,12 +199,6 @@ def _benchmark_hexagon_elementwise_add_kernel(
     try:
         ns_tir_module = _get_irmod_elemwise_add(shape, dtype, mem_scope)

-        # Dump the primfunc NS-TIR (as text) to the log file...
-        lowered_mod = tvm.lower(ns_tir_module, _PRIMFUNC_NAME)
-        log_file.write("LOWERED IR MODULE:\n")
-        log_file.write(str(lowered_mod))
-        log_file.write("\n")
-
         # Lower the primfunc's IRModule to Hexagon object code...
         input1 = tvm.te.placeholder(shape, dtype=dtype)
         input2 = tvm.te.placeholder(shape, dtype=dtype)
diff --git a/tests/python/contrib/test_hexagon/test_meta_schedule.py b/tests/python/contrib/test_hexagon/test_meta_schedule.py
index 26acedb88e21..c0c7355a9afa 100644
--- a/tests/python/contrib/test_hexagon/test_meta_schedule.py
+++ b/tests/python/contrib/test_hexagon/test_meta_schedule.py
@@ -156,7 +156,7 @@ def schedule_dense(sch, block, m_size, do_tune):


 def verify_dense(sch, target, m_size, n_size, k_size, hexagon_session):
     """Verify dense operator."""
-    f = tvm.build(sch.mod["main"], target=target, name="dense")
+    f = tvm.build(sch.mod["main"], target=target)
     mod = hexagon_session.load_module(f)
     dev = hexagon_session.device
diff --git a/tests/python/contrib/test_hexagon/test_sigmoid.py b/tests/python/contrib/test_hexagon/test_sigmoid.py
index cc633795c217..1247d9075972 100644
--- a/tests/python/contrib/test_hexagon/test_sigmoid.py
+++ b/tests/python/contrib/test_hexagon/test_sigmoid.py
@@ -92,7 +92,7 @@ def test_sigmoid(
     func_name = "sigmoid"

     with tvm.transform.PassContext(opt_level=3):
-        runtime_module = tvm.build(tir_s.mod, target=get_hexagon_target("v69"), name=func_name)
+        runtime_module = tvm.build(tir_s.mod, target=get_hexagon_target("v69"))

     assert "hvx_sigmoid" in runtime_module.get_source("asm")
     assert "vmin" in runtime_module.get_source("asm")
diff --git a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
index 498e29e407b4..d45b35befd11 100644
--- a/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
+++ b/tests/python/contrib/test_hexagon/test_software_pipeline_async.py
@@ -183,7 +183,6 @@ def test_async_software_pipeline(
             "tir.experimental_dma_bypass_cache": 1,
         }
     ):
-        # tvm.lower(schedule.mod["main"]).show()
         func = tvm.build(schedule.mod["main"], target=get_hexagon_target("v68"))

     with hexagon_launcher.create_session() as hexagon_session:
diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py
index cfeb70b96388..718cf3a663e5 100644
--- a/tests/python/ir/test_pass_instrument.py
+++ b/tests/python/ir/test_pass_instrument.py
@@ -38,7 +38,7 @@ def func(a: T.handle, b: T.handle) -> None:
                     B[vi, vj, vk, vl] = A[vi, vj, vk, vl] * 2.0

     with tvm.transform.PassContext(opt_level=3, instruments=[PrintBeforeAll(), PrintAfterAll()]):
-        tvm.lower(func)
+        tvm.build(func)

     all_passes_output = capsys.readouterr().out
     assert "Before Running Pass:" in all_passes_output
     assert "After Running Pass:" in all_passes_output
diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py
deleted file mode 100644
index edb3ed351e5d..000000000000
--- a/tests/python/tir-base/test_lower_build.py
+++ /dev/null
@@ -1,133 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import numpy as np
-
-import tvm
-from tvm.ir.module import IRModule
-from tvm.script import tir as T
-import tvm.testing
-
-
-def _check_module_with_numpy(mod, shape=(128, 128, 128)):
-    m, n, k = shape
-    a = tvm.nd.array(np.random.rand(m, k).astype("float32"))
-    b = tvm.nd.array(np.random.rand(n, k).astype("float32"))
-    c = tvm.nd.array(np.zeros((m, n), dtype="float32"))
-    c_np = np.dot(a.numpy(), b.numpy().transpose())
-    mod(a, b, c)
-    tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
-
-
-# pylint: disable=no-self-argument, missing-class-docstring, missing-function-docstring
-@T.prim_func
-def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, [128, 128])
-    B = T.match_buffer(b, [128, 128])
-    C = T.match_buffer(c, [128, 128])
-    for i, j in T.grid(128, 128):
-        with T.block("init"):
-            vi, vj = T.axis.remap("SS", [i, j])
-            C[vi, vj] = T.float32(0)
-        for k in range(128):
-            with T.block("update"):
-                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
-                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
-
-
-@tvm.script.ir_module
-class LoweredModule:
-    @T.prim_func
-    def main(
-        A: T.Buffer((128, 128), "float32"),
-        B: T.Buffer((128, 128), "float32"),
-        C: T.Buffer((128, 128), "float32"),
-    ) -> None:
-        # function attr dict
-        T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True})
-        A_flat = T.Buffer([16384], data=A.data)
-        B_flat = T.Buffer([16384], data=B.data)
-        C_flat = T.Buffer([16384], data=C.data)
-        # body
-        for x, y in T.grid(128, 128):
-            C_flat[x * 128 + y] = 0.0
-            for k in T.serial(0, 128):
-                C_flat[x * 128 + y] = (
-                    C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k]
-                )
-
-
-@tvm.script.ir_module
-class LoweredTIRModule:
-    @T.prim_func
-    def main(
-        A: T.Buffer((128, 128), "float32"),
-        B: T.Buffer((128, 128), "float32"),
-        C: T.Buffer((128, 128), "float32"),
-    ) -> None:
-        # function attr dict
-        T.func_attr({"global_symbol": "main", "tir.noalias": True})
-        A_flat = T.Buffer([16384], data=A.data)
-        B_flat = T.Buffer([16384], data=B.data)
-        C_flat = T.Buffer([16384], data=C.data)
-        # body
-        for x, y in T.grid(128, 128):
-            C_flat[x * 128 + y] = 0.0
-            for k in T.serial(0, 128):
-                C_flat[x * 128 + y] = (
-                    C_flat[x * 128 + y] + A_flat[x * 128 + k] * B_flat[y * 128 + k]
-                )
-
-
-def test_lower_build_tir_func():
-    # check lowering with the CSE pass disabled as otherwise it would do some commoning
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
-        ir_mod = tvm.lower(matmul)
-    tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule)
-    # check building
-    mod = tvm.build(matmul, target="llvm")
-    _check_module_with_numpy(mod)
-
-
-def test_lower_build_tir_module():
-    func = matmul.with_attr("global_symbol", "main")
-    func = func.with_attr("tir.noalias", T.bool(True))
-    ir_mod = IRModule({"main": func})
-    # check lowering with the CSE pass disabled as otherwise it would do some commoning
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
-        lowered_mod = tvm.lower(ir_mod)
-    tvm.ir.assert_structural_equal(lowered_mod, LoweredTIRModule)
-    # check building
-    mod = tvm.build(ir_mod, target="llvm")
-    _check_module_with_numpy(mod)
-
-
-def test_lower_build_lowered_module():
-    # check lowering with the CSE pass disabled as otherwise it would do some commoning
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
-        ir_mod = tvm.lower(LoweredTIRModule)
-    tvm.ir.assert_structural_equal(ir_mod, LoweredTIRModule)
-    # check building
-    mod = tvm.build(ir_mod, target="llvm")
-    _check_module_with_numpy(mod)
-
-
-if __name__ == "__main__":
-    test_lower_build_te_schedule()
-    test_lower_build_tir_func()
-    test_lower_build_tir_module()
-    test_lower_build_lowered_module()
diff --git a/tests/python/tir-base/test_tir_te_extern_primfunc.py b/tests/python/tir-base/test_tir_te_extern_primfunc.py
index 45ca7a1c7256..16bc0b0ae2fc 100644
--- a/tests/python/tir-base/test_tir_te_extern_primfunc.py
+++ b/tests/python/tir-base/test_tir_te_extern_primfunc.py
@@ -192,7 +192,6 @@ def test_te_extern_call(self, func, params, verify):
         input_tensors = [te.placeholder(buf_name_map[name].shape) for name in params]
         output = te.extern_primfunc(input_tensors, prim_func)
         rt_prim_func = te.create_prim_func(tensors_from_extern_op(output, prim_func))
-        tvm.ir.assert_structural_equal(tvm.lower(prim_func), tvm.lower(rt_prim_func))

         target = tvm.target.Target("llvm")
         func = tvm.build(rt_prim_func, target=target)
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index 390745fe9d96..fe9998bc798e 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -117,7 +117,7 @@ def run_test(
         mma_store_intrin,
     )

-    f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+    f = tvm.build(sch.mod["main"], target="cuda")

     dev = tvm.device("cuda", 0)
diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
index 8077a603bcf2..2b3e6ce39bfb 100644
--- a/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
+++ b/tests/python/tir-schedule/test_tir_schedule_tensorize_mfma_numeric.py
@@ -109,7 +109,7 @@ def run_test(
         mma_store_intrin,
     )

-    f = tvm.build(sch.mod["main"], target="rocm", name="dense")
+    f = tvm.build(sch.mod["main"], target="rocm")

     dev = tvm.device("rocm", 0)

     if in_dtype == "float32":
diff --git a/tests/python/tir-transform/test_tir_transform_convert_ssa.py b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
index ec768ba74f7b..b93747c84a09 100644
--- a/tests/python/tir-transform/test_tir_transform_convert_ssa.py
+++ b/tests/python/tir-transform/test_tir_transform_convert_ssa.py
@@ -234,41 +234,6 @@ def func(A: T.Buffer(1, "float32")):

         assert before.same_as(after)


-class TestDedupAutoBroadcastBuffer(BaseBeforeAfter):
-    """De-dup auto-broadcast buffers
-
-    Auto-broadcast buffers can define additional variables during the
-    `Buffer::Buffer` constructor for the strides. This is intended to
-    be used for match buffers, where these variables are defined based
-    on the argument being passed in.
-
-    These additional variables can cause errors when copying a buffer
-    with the `Buffer::Buffer` constructor. If a buffer has non-empty
-    shape, empty strides, and kAutoBroadcast type, then the resulting
-    buffer will have additional strides defined. Such a buffer can
-    result from lowering of a scalar buffer, which will be flattened
-    to a shape of [1].
-
-    Previous implementations of ConvertSSA incorrectly handled this
-    case, resulting in undefined stride variables.
-    """
-
-    def _make_func(self):
-        @T.prim_func
-        def func(a: T.handle):
-            A = T.match_buffer(a, shape=(), dtype="float32", buffer_type="auto")
-            A[()] = 1.0
-
-        return tvm.lower(func)["main"]
-
-    def before(self):
-        func = self._make_func()
-        return tvm.IRModule({"func_a": func, "func_b": func})
-
-    def expected(self):
-        return tvm.IRModule({"func_a": self._make_func(), "func_b": self._make_func()})
-
-
 class TestKeepDuplicateThreadIdxInSameFunction(BaseBeforeAfter):
     """Environment threads are treated as being at function scope
diff --git a/tests/python/tir-transform/test_tir_transform_extract_constants.py b/tests/python/tir-transform/test_tir_transform_extract_constants.py
index b3e0aa74f96d..cbfb6d39bcd2 100644
--- a/tests/python/tir-transform/test_tir_transform_extract_constants.py
+++ b/tests/python/tir-transform/test_tir_transform_extract_constants.py
@@ -63,8 +63,6 @@ def _visit(stmt):
     for n, f in mod.functions.items():
         tvm.tir.stmt_functor.post_order_visit(f.body, _visit)

-    tvm.lower(mod)
-

 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
index b215398622cc..925f004cc527 100644
--- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
+++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py
@@ -322,36 +322,5 @@ def expected():
                         T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])


-def test_lower_2d_physical_memory():
-    """Axis separators should preserve 2-d buffers through lowering.
-
-    A catch-all test to ensure that defining axis_separators is
-    sufficient to maintain non-flat buffer descriptions through all
-    lowering steps.
-    """
-
-    # This test doesn't use CompareBeforeAfter, because the after step
-    # is not currently expressible in TVMScript. This test can be
-    # re-written after https://github.com/apache/tvm/pull/12412.
-
-    @T.prim_func
-    def func():
-        buf = T.alloc_buffer(
-            [1, 1],
-            dtype="int32",
-            scope="global",
-            axis_separators=[1],
-        )
-        buf[0, 0] = 0
-
-    lowered = tvm.lower(func)["main"]
-    assert isinstance(lowered.body, tvm.tir.Allocate)
-    assert list(lowered.body.extents) == [1, 1], (
-        "Non-flat buffer allocations, "
-        "marked by axis_separators, "
-        "flattened to flat memory allocation."
-    )
-
-
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
index 754ce032404d..0a040b0eeadb 100644
--- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
+++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py
@@ -179,7 +179,7 @@ def build_tir():
     )

     mod = build_tir()
-    f = tvm.build(mod, None, "llvm")
+    f = tvm.build(mod, None)
     a = tvm.nd.array(np.zeros(2, dtype="float32"))
     f(a)
     tvm.testing.assert_allclose(a.numpy(), expected_value)
diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
index 93c680c846c5..a7b528093967 100644
--- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
+++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py
@@ -111,33 +111,6 @@ def check(m, n, target_bits, target_dtype):
     check(2**14, 32, target_bits=16, target_dtype="int32")


-def test_thread_axis_2():
-    # fmt: off
-    @tvm.script.ir_module
-    class Before:
-        @T.prim_func
-        def main(T_reshape: T.Buffer((1, 12, 384, 384), "float32"), placeholder_1: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "bool"), T_where: T.Buffer((T.int64(1), T.int64(12), T.int64(384), 384), "float32")) -> None:
-            # function attr dict
-            T.func_attr({"global_symbol": "main", "tir.noalias": True})
-            # body
-            # with T.block("root")
-            for i0_i1_i2_i3_fused_1 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
-                for i0_i1_i2_i3_fused_2 in T.thread_binding(T.int64(1024), thread="threadIdx.x"):
-                    for i0_i1_i2_i3_fused_0 in T.serial(T.int64(7)):
-                        with T.block("T_where"):
-                            ax0 = T.axis.spatial(T.int64(1), T.int64(0))
-                            ax1 = T.axis.spatial(T.int64(12), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(1769472) // T.int64(147456))
-                            ax2 = T.axis.spatial(T.int64(384), ((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(147456) // T.int64(384))
-                            ax3 = T.axis.spatial(384, T.cast(((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2) % T.int64(384), "int32"))
-                            T.where((i0_i1_i2_i3_fused_0 * T.int64(256) + i0_i1_i2_i3_fused_1) * T.int64(1024) + i0_i1_i2_i3_fused_2 < T.int64(1769472))
-                            T.reads(placeholder_1[ax0, ax1, ax2, ax3], T_reshape[ax0, ax1, ax2, ax3])
-                            T.writes(T_where[ax0, ax1, ax2, ax3])
-                            T_where[ax0, ax1, ax2, ax3] = T.Select(T.cast(placeholder_1[ax0, ax1, ax2, ax3], "int32") != 0, T.float32(-1000000000), T_reshape[ax0, ax1, ax2, ax3])
-    # fmt: on
-    # TODO(@junrushao1994): make this test more "unit" after the new TVMScript printer/parser lands
-    tvm.lower(Before)
-
-
 def test_multilanes():
     def check(m, lanes, target_bits, target_dtype):
         ib = tvm.tir.ir_builder.create()
diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
index ab91c6c7b330..548b199a94ce 100644
--- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
+++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py
@@ -610,87 +610,5 @@ def expected(A: T.Buffer(16, "float32"), D: T.Buffer(16, "float32")):
             D[i] = C[i]


-def test_vulkan_smem_reuse():
-    target = tvm.target.Target(
-        {
-            "keys": ["vulkan", "gpu"],
-            "kind": "vulkan",
-            "max_num_threads": 256,
-            "max_threads_per_block": 256,
-            "supports_float32": True,
-            "supports_int32": True,
-            "tag": "",
-            "thread_warp_size": 1,
-        }
-    )
-
-    @T.prim_func(private=True)
-    def func(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
-        T.func_attr({"tir.noalias": T.bool(True)})
-        A_shared = T.allocate([4], "float32", "shared")
-        A_local = T.allocate([4], "float32", "local")
-        B_shared = T.allocate([4], "float16", "shared")
-        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_1 = T.Buffer((4,), data=A.data)
-            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
-        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
-        B_shared_1 = T.Buffer((4,), "float16", data=B_shared, scope="shared")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            B_shared_1[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
-        threadIdx_x = T.launch_thread("threadIdx.x", 4)
-        B_1 = T.Buffer((4,), "float16", data=B.data)
-        B_1[threadIdx_x] = B_shared_1[threadIdx_x]
-
-    @T.prim_func(private=True)
-    def normal_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
-        T.func_attr({"tir.noalias": T.bool(True)})
-        A_shared = T.allocate([4], "float32", "shared")
-        A_local = T.allocate([4], "float32", "local")
-        A_shared_1 = T.Buffer((4,), data=A_shared, scope="shared")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_1 = T.Buffer((4,), data=A.data)
-            A_shared_1[threadIdx_x] = A_1[threadIdx_x]
-        A_local_1 = T.Buffer((4,), data=A_local, scope="local")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_local_1[threadIdx_x] = A_shared_1[threadIdx_x]
-        A_shared_2 = T.Buffer((4,), "float16", data=A_shared, scope="shared")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_shared_2[threadIdx_x] = T.Cast("float16", A_local_1[threadIdx_x])
-        threadIdx_x = T.launch_thread("threadIdx.x", 4)
-        B_1 = T.Buffer((4,), "float16", data=B.data)
-        B_1[threadIdx_x] = A_shared_2[threadIdx_x]
-
-    @T.prim_func(private=True)
-    def no_reuse_lowering(A: T.Buffer((4,), "float32"), B: T.Buffer((4,), "float16")):
-        T.func_attr({"target": target, "tir.noalias": T.bool(True)})
-        A_shared_1 = T.allocate([4], "float32", "shared")
-        A_local_1 = T.allocate([4], "float32", "local")
-        B_shared_1 = T.allocate([4], "float16", "shared")
-        A_shared_1_1 = T.Buffer((4,), data=A_shared_1, scope="shared")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_1 = T.Buffer((4,), data=A.data)
-            A_shared_1_1[threadIdx_x] = A_1[threadIdx_x]
-        A_local_1_1 = T.Buffer((4,), data=A_local_1, scope="local")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            A_local_1_1[threadIdx_x] = A_shared_1_1[threadIdx_x]
-        B_shared_1_1 = T.Buffer((4,), "float16", data=B_shared_1, scope="shared")
-        with T.launch_thread("threadIdx.x", 4) as threadIdx_x:
-            B_shared_1_1[threadIdx_x] = T.Cast("float16", A_local_1_1[threadIdx_x])
-        threadIdx_x = T.launch_thread("threadIdx.x", 4)
-        B_1 = T.Buffer((4,), "float16", data=B.data)
-        B_1[threadIdx_x] = B_shared_1_1[threadIdx_x]
-
-    # Reuse shared memory when lowering without target.
-    mod = tvm.IRModule({"main": func})
-    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], normal_lowering)
-
-    # No shared memory reuse when lowering with target Vulkan.
-    mod = tvm.tir.transform.BindTarget(target)(mod)
-    tvm.ir.assert_structural_equal(tvm.lower(mod)["main"], no_reuse_lowering)
-
-
 if __name__ == "__main__":
     tvm.testing.main()