
[TIR][Schedule] TileWithTensorIntrin skip incorrect ComputeInline for input-padding #16239

Merged

Conversation


@XFPlus XFPlus commented Dec 14, 2023

Previously, once padding happened, TileWithTensorIntrin tried to inline all producers into the padding blocks, even when only one producer was actually padded.
This PR fixes that.
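To illustrate the idea with a pure-Python sketch (this is not the actual C++ change in transform.cc; `Producer`, `writes`, and `producers_to_inline` are hypothetical names used only for illustration): the point is to inline only the producers whose output buffer was actually padded, rather than all producers unconditionally.

```python
from dataclasses import dataclass


@dataclass
class Producer:
    name: str
    writes: str  # name of the buffer this producer block writes


def producers_to_inline(producers, padded_buffers):
    """Keep only the producers whose output buffer was padded."""
    return [p for p in producers if p.writes in padded_buffers]


# Suppose only X was padded; the producer feeding the residual input Z
# must be left alone instead of being inlined into a padding block.
producers = [Producer("X_reindex", "X_pad"), Producer("Z_reindex", "Z")]
selected = producers_to_inline(producers, {"X_pad"})
print([p.name for p in selected])  # ['X_reindex']
```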

Possibly related issue: #15446

cc @tqchen @junrushao

@github-actions github-actions bot requested a review from tqchen December 14, 2023 10:02
@XFPlus XFPlus force-pushed the fix-tir-tile-with-tensor-intrin-when-padding branch from aae820e to 63b855a Compare December 14, 2023 10:08

XFPlus commented Dec 14, 2023

@junrushao
Additionally, the ComputeInline of consumers (in fact, of tensor_core_reindex_store) conflicts with the ReverseComputeInline of tensor_core_reindex_store in AddWriteReuseTensorCore.
I checked the previous version: inlining the following block was done by another rule. Maybe we could make this rule more compatible by removing this ComputeInline.


spectrometerHBH commented Dec 19, 2023

Would you mind giving an example of this, which can effectively be taken as a test case?


XFPlus commented Dec 19, 2023

Would you mind giving an example of this, which can effectively be taken as a test case?
@spectrometerHBH

I hit this issue when trying to tune a conv2d with a residual add. There are two test cases, one for the producers' compute-inline and one for the consumers' compute-inline. It would be better if TileWithTensorIntrin could cover the various padding situations.
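For context on why the shapes in the test cases below trigger padding: the wmma intrinsics used by the rule operate on 16x16x16 tiles, so a dimension is padded exactly when it is not a multiple of 16. A quick sketch (`padded_extent` is a hypothetical helper for illustration, not a TVM API):

```python
def padded_extent(extent: int, tile: int = 16) -> int:
    """Round an extent up to the next multiple of the tensor-intrin tile."""
    return -(-extent // tile) * tile  # ceiling division without imports


# The first test case uses M=1023, which is padded up to 1024;
# the second uses K=4095, which is padded up to 4096, while
# M=N=1024 are already multiples of 16 and need no padding.
print(padded_extent(1023), padded_extent(4095), padded_extent(1024))
```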

Specifically, when the block to be tiled has two producers and only one of them produces a padded input (the other keeps the original input), we see the first error:

Traceback (most recent call last):
  File "test_multilevel_tiling.py", line 95, in <module>
    test_wmma_producer_computinline()
  File "test_multilevel_tiling.py", line 84, in test_wmma_producer_computinline
    actual = _design_space(mod, "float16")
  File "test_multilevel_tiling.py", line 74, in _design_space
    return generate_design_space(
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/meta_schedule/testing/space_generation.py", line 68, in generate_design_space
    return ms.TuneContext(
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/meta_schedule/tune_context.py", line 167, in generate_design_space
    return self.space_generator.generate_design_space(self.mod)
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/meta_schedule/space_generator/space_generator.py", line 86, in generate_design_space
    return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod)  # type: ignore # pylint: disable=no-member
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.error.InternalError: Traceback (most recent call last):
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::tir::Schedule, void> (tvm::meta_schedule::SpaceGenerator, tvm::IRModule const&)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::SpaceGenerator, tvm::meta_schedule::SpaceGeneratorNode, tvm::runtime::Array<tvm::tir::Schedule, void>, tvm::IRModule const&, void>(tvm::runtime::Array<tvm::tir::Schedule, void> (tvm::meta_schedule::SpaceGeneratorNode::*)(tvm::IRModule const&))::{lambda(tvm::meta_schedule::SpaceGenerator, tvm::IRModule const&)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::SpaceGenerator, tvm::meta_schedule::SpaceGeneratorNode, tvm::runtime::Array<tvm::tir::Schedule, void>, tvm::IRModule const&, void>(tvm::runtime::Array<tvm::tir::Schedule, void> (tvm::meta_schedule::SpaceGeneratorNode::*)(tvm::IRModule const&))::{lambda(tvm::meta_schedule::SpaceGenerator, tvm::IRModule const&)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  5: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
  4: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
  3: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<tvm::meta_schedule::State, std::allocator<tvm::meta_schedule::State> >)
  2: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::TransformForTensorization(tvm::meta_schedule::TensorCoreState) const
  1: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(tvm::meta_schedule::TensorCoreStateNode*, tvm::runtime::String const&) const
  0: tvm::tir::TileWithTensorIntrin(tvm::tir::Schedule const&, tvm::tir::BlockRV const&, tvm::runtime::String const&, bool)
  File "/home/alex.wang/ws/project/tvm_0.14.0/src/tir/schedule/transform.cc", line 322
InternalError: Check failed: original_producers.size() == 1u (0 vs. 1) :

And when the block's output needs no padding, the following consumer is wrongly inlined, so we get the second error:

Traceback (most recent call last):
  File "test_multilevel_tiling.py", line 96, in <module>
    test_wmma_consumer_computinline()
  File "test_multilevel_tiling.py", line 90, in test_wmma_consumer_computinline
    actual = _design_space(mod, "float16")
  File "test_multilevel_tiling.py", line 74, in _design_space
    return generate_design_space(
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/meta_schedule/testing/space_generation.py", line 68, in generate_design_space
    return ms.TuneContext(
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/meta_schedule/tune_context.py", line 167, in generate_design_space
    return self.space_generator.generate_design_space(self.mod)
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/meta_schedule/space_generator/space_generator.py", line 86, in generate_design_space
    return _ffi_api.SpaceGeneratorGenerateDesignSpace(self, mod)  # type: ignore # pylint: disable=no-member
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/home/alex.wang/ws/project/tvm_0.14.0/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
ValueError: Traceback (most recent call last):
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Array<tvm::tir::Schedule, void> (tvm::meta_schedule::SpaceGenerator, tvm::IRModule const&)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::meta_schedule::SpaceGenerator, tvm::meta_schedule::SpaceGeneratorNode, tvm::runtime::Array<tvm::tir::Schedule, void>, tvm::IRModule const&, void>(tvm::runtime::Array<tvm::tir::Schedule, void> (tvm::meta_schedule::SpaceGeneratorNode::*)(tvm::IRModule const&))::{lambda(tvm::meta_schedule::SpaceGenerator, tvm::IRModule const&)#1}>(tvm::runtime::Registry::set_body_method<tvm::meta_schedule::SpaceGenerator, tvm::meta_schedule::SpaceGeneratorNode, tvm::runtime::Array<tvm::tir::Schedule, void>, tvm::IRModule const&, void>(tvm::runtime::Array<tvm::tir::Schedule, void> (tvm::meta_schedule::SpaceGeneratorNode::*)(tvm::IRModule const&))::{lambda(tvm::meta_schedule::SpaceGenerator, tvm::IRModule const&)#1}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  6: tvm::meta_schedule::PostOrderApplyNode::GenerateDesignSpace(tvm::IRModule const&)
  5: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::Apply(tvm::tir::Schedule const&, tvm::tir::BlockRV const&)
  4: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::ApplySubRules(std::vector<tvm::meta_schedule::State, std::allocator<tvm::meta_schedule::State> >)
  3: tvm::meta_schedule::MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(tvm::meta_schedule::TensorCoreState) const
  2: tvm::tir::TracedScheduleNode::ReverseComputeInline(tvm::tir::BlockRV const&)
  1: tvm::tir::ConcreteScheduleNode::ReverseComputeInline(tvm::tir::BlockRV const&)
  0: tvm::tir::ConcreteScheduleNode::GetSRef(tvm::tir::BlockRV const&) const
  File "/home/alex.wang/ws/project/tvm_0.14.0/src/tir/schedule/./concrete_schedule.h", line 285
ValueError: The block no longer exists in the IRModule

The test cases are shown below.

import tempfile
import numpy as np

import tvm
from tvm import te
from tvm import meta_schedule as ms
from tvm._ffi import register_func
from tvm.meta_schedule.testing.space_generation import (
    check_sketches,
    generate_design_space,
)
from tvm.meta_schedule.builder import LocalBuilder
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule
from tvm.tir.schedule import Trace

# get tensor intrin
from tvm.tir.tensor_intrin import cuda  # pylint: disable=unused-import

import tvm.testing


def matmul_fp16(N: int, M: int, K: int, out_dtype: str):
    x = te.placeholder((N, K), name="X", dtype="float16")
    y = te.placeholder((K, M), name="Y", dtype="float16")
    z = te.placeholder((N, M), name="Z", dtype=out_dtype)
    k = te.reduce_axis((0, K), name="k")
    c = te.compute(  # pylint: disable=invalid-name
        (N, M),
        lambda i, j: te.sum(x[i][k].astype(out_dtype) * y[k][j].astype(out_dtype), axis=[k]),
        name="C",
    )
    d = te.compute(
        (N, M),
        lambda i, j: c[i][j].astype(out_dtype) + z[i][j].astype(out_dtype),
        name="D",
    )
    return (x, y, z, d)


def multi_level_tiling_mma(out_dtype):
    simplify_dict = {"float32": "f32", "float16": "f16"}
    out_dtype = simplify_dict[out_dtype]
    return ms.schedule_rule.MultiLevelTilingTensorCore(
        intrin_groups=[
            {
                "init": f"wmma_fill_16x16x16_{out_dtype}",
                "load_a": "wmma_load_16x16x16_f16_a_shared_dyn",
                "load_b": "wmma_load_16x16x16_f16_b_shared_dyn",
                "compute": f"wmma_sync_16x16x16_f16f16{out_dtype}",
                "store": f"wmma_store_16x16x16_{out_dtype}_shared_dyn",
            },
        ],
        structure="SSSRRSRS",
        tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
        max_innermost_factor=4,  # 64 // tensor intrin size
        vector_load_lens=[1, 2, 3, 4, 8, 16],
        reuse_read=ms.schedule_rule.ReuseType(
            req="must",
            levels=[4],
            scope="shared.dyn",
        ),
        reuse_write=ms.schedule_rule.ReuseType(
            req="must",
            levels=[2],
            scope="shared.dyn",
        ),
        use_software_pipeline=True,
    )

def _design_space(mod, out_dtype):
    return generate_design_space(
        kind="cuda-tensorcore",
        mod=mod,
        target=Target("nvidia/geforce-rtx-3080"),
        types=None,
        sch_rules=[multi_level_tiling_mma(out_dtype)],
    )

def test_wmma_producer_computinline():
    mod = te.create_prim_func(matmul_fp16(M=1023, N=1024, K=4096, out_dtype="float16")) # only one matrix padded
    actual = _design_space(mod, "float16")
    for s in actual:
        print(s.mod)

def test_wmma_consumer_computinline():
    mod = te.create_prim_func(matmul_fp16(M=1024, N=1024, K=4095, out_dtype="float16")) # output matrix doesn't need padding
    actual = _design_space(mod, "float16")
    for s in actual:
        print(s.mod)

if __name__ == '__main__':
    test_wmma_producer_computinline()
    test_wmma_consumer_computinline()
