[Bug] Tensorization Failure During Multilevel Tiling with Tensor Intrin #16614

Closed
zxybazh opened this issue on Feb 21, 2024 · 5 comments
Labels: tune:meta_schedule (src/meta_schedule, python/tvm/meta_schedule), type: bug

Comments

@zxybazh (Member) commented Feb 21, 2024

Expected behavior

MetaSchedule tuning works for the given Conv2d workload.

Actual behavior

Applying the schedule rule Multi-level tiling with tensor intrin triggers ValueError: The block no longer exists in the IRModule. I noticed that state->tensor_core_reindex_store points to a block that has already been merged into another block via ComputeInline during the application of TileWithTensorIntrin.

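For context, this dangling-reference failure can be reproduced in isolation. The snippet below is a minimal sketch of the same class of error (not the actual code path inside the schedule rule): it keeps a BlockRV around after the corresponding block has been inlined away and then tries to resolve it.

import tvm
from tvm import tir
from tvm.script import tir as T

@T.prim_func
def two_stage(a: T.Buffer((16,), "float16"), c: T.Buffer((16,), "float16")):
    # "B" is an intermediate stage that will later be inlined into "C"
    b = T.alloc_buffer((16,), "float16")
    for i in T.serial(16):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            b[vi] = a[vi] * T.float16(2)
    for i in T.serial(16):
        with T.block("C"):
            vi = T.axis.spatial(16, i)
            c[vi] = b[vi] + T.float16(1)

sch = tir.Schedule(two_stage)
block_b = sch.get_block("B")
sch.compute_inline(block_b)  # "B" is merged into "C"; block_b now dangles
try:
    sch.get_loops(block_b)   # resolving the stale BlockRV fails
except ValueError as err:
    print(err)               # reports that the block no longer exists in the IRModule
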
Environment

Latest TVM Main

Steps to reproduce

import tvm
from tvm.script import tir as T
from tvm import meta_schedule as ms

@T.prim_func(private=True)
def fused_conv2d_add1(reshape3: T.Buffer((T.int64(50), T.int64(8), T.int64(72), T.int64(128)), "float16"), conv_in_weight: T.Buffer((T.int64(320), T.int64(8), T.int64(3), T.int64(3)), "float16"), lv23: T.Buffer((T.int64(1), T.int64(320), T.int64(1), T.int64(1)), "float16"), T_add_intermediate: T.Buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")):
    T.func_attr({"tir.noalias": T.bool(True)})
    # with T.block("root"):
    pad_temp = T.alloc_buffer((T.int64(50), T.int64(8), T.int64(74), T.int64(130)), "float16")
    conv2d_nchw_intermediate = T.alloc_buffer((T.int64(50), T.int64(320), T.int64(72), T.int64(128)), "float16")
    for i0, i1, i2, i3 in T.grid(T.int64(50), T.int64(8), T.int64(74), T.int64(130)):
        with T.block("pad_temp"):
            v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
            T.reads(reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)])
            T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
            pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(73) and T.int64(1) <= v_i3 and v_i3 < T.int64(129), reshape3[v_i0, v_i1, v_i2 - T.int64(1), v_i3 - T.int64(1)], T.float16(0))
    for nn, ff, yy, xx, rc, ry, rx in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128), T.int64(8), T.int64(3), T.int64(3)):
        with T.block("conv2d_nchw"):
            v_nn, v_ff, v_yy, v_xx, v_rc, v_ry, v_rx = T.axis.remap("SSSSRRR", [nn, ff, yy, xx, rc, ry, rx])
            T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx], conv_in_weight[v_ff, v_rc, v_ry, v_rx])
            T.writes(conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx])
            with T.init():
                conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = T.float16(0)
            conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] = conv2d_nchw_intermediate[v_nn, v_ff, v_yy, v_xx] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx] * conv_in_weight[v_ff, v_rc, v_ry, v_rx]
    for ax0, ax1, ax2, ax3 in T.grid(T.int64(50), T.int64(320), T.int64(72), T.int64(128)):
        with T.block("T_add"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)])
            T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
            T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = conv2d_nchw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] + lv23[T.int64(0), v_ax1, T.int64(0), T.int64(0)]

target = tvm.target.Target("nvidia/nvidia-a10g")
func = fused_conv2d_add1
ms.tune_tir(func, target=target, max_trials_global=100, work_dir="./temp")
zxybazh added the type: bug and tune:meta_schedule (src/meta_schedule, python/tvm/meta_schedule) labels on Feb 21, 2024
@krishnab30 commented Mar 26, 2024

Hi @zxybazh, I am facing the same issue when I try to tune a ResNet-50 Relay workload with MetaSchedule.

@zxybazh (Member, Author) commented Sep 13, 2024

Won't fix for now. Please feel free to reopen if this still reproduces.

zxybazh closed this as completed on Sep 13, 2024
@Yongqi-Zhuo (Contributor) commented:

I don't know why this issue was closed; it still reproduces for me when tuning a quantized ResNet-18.

After some searching I found that #16239 and #15505 may be related to this bug. I tried the workaround in #15505: although AddWriteReuseTensorCore no longer throws ValueError: The block no longer exists in the IRModule, AddReadReuseTensorCore now throws ScheduleError: An error occurred in the schedule primitive 'compute-at': The scope tir.Block#0 is not a stage pipeline. Unfortunately I have no idea where to start debugging.

@zxybazh Could you please look into this a little more? The comments in #16239 might help. Also, do you see any unsoundness in the workaround in #15505?
@XFPlus Did you eventually find a workaround for this issue? Have you tried #15505?

Thank you for your help.

@zxybazh (Member, Author) commented Sep 29, 2024

Thanks for reporting. I think the guard proposed in #15505 could be a temporary fix, but a more principled solution would be to fix the tensorization rule. CC @vinx13
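For illustration only, a Python-level analogue of such a guard might look like the sketch below. This is not what #15505 actually changes (that patch is on the C++ side), and the helper name here is hypothetical; it merely shows the idea of checking that a BlockRV still resolves before operating on it.

import tvm
from tvm import tir

def block_still_exists(sch: tir.Schedule, block_rv) -> bool:
    # Hypothetical guard: returns False when the BlockRV can no longer be
    # resolved, e.g. because the block was inlined away earlier.
    try:
        sch.get(block_rv)
        return True
    except (ValueError, tvm.TVMError):
        return False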

@Yongqi-Zhuo (Contributor) commented:

After some digging I discovered the root cause of this bug: TileWithTensorIntrin in src/tir/schedule/transform.cc calls PadEinsum and then inlines the padded input (output) into its original producer (consumer). However, it is possible that one of the inputs/outputs does not need padding, in which case its producer (consumer) is not touched by PadEinsum. This means that TileWithTensorIntrin may inline blocks that are irrelevant to padding and must not be inlined.

I found that #17171 is also related to this bug, but that PR, along with #16239 and #15505, does not really solve the problem.

I have fixed this by tracking which producers (consumers) are padded by PadEinsum. I will file a PR later, and hopefully it will be able to replace #17171 and #16239.
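A rough Python-level sketch of the tracking idea, for illustration only: the real fix belongs in the C++ TileWithTensorIntrin (src/tir/schedule/transform.cc), and the helper below is hypothetical rather than existing TVM API. It assumes block names are stable across pad_einsum and treats a producer/consumer as "padded" only when its block is no longer structurally identical to what it was before padding; only those blocks are inlined, so BlockRVs held elsewhere (such as the re-index blocks) do not dangle.

from typing import List
import tvm
from tvm import tir

def pad_and_inline_only_padded(sch: tir.Schedule, block, padding: List[int]) -> None:
    # Snapshot the producer/consumer blocks before padding, keyed by block name.
    before = {}
    for rv in list(sch.get_producers(block)) + list(sch.get_consumers(block)):
        blk = sch.get(rv)
        before[blk.name_hint] = blk

    sch.pad_einsum(block, padding)

    # Inline only the producers that PadEinsum actually rewrote; blocks that
    # are unrelated to padding are left alone.
    for rv in sch.get_producers(block):
        blk = sch.get(rv)
        old = before.get(blk.name_hint)
        if old is None or not tvm.ir.structural_equal(old, blk):
            sch.compute_inline(rv)

    # Same treatment for consumers, using reverse_compute_inline.
    for rv in sch.get_consumers(block):
        blk = sch.get(rv)
        old = before.get(blk.name_hint)
        if old is None or not tvm.ir.structural_equal(old, blk):
            sch.reverse_compute_inline(rv)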
