Skip to content

Commit

Permalink
[Bug] Fix x86 dense schedule extern ops (apache#8420)
Browse files Browse the repository at this point in the history
* [Bug] Fix x86 dense schedule extern ops

* more

* lint
  • Loading branch information
comaniac authored and ylc committed Jan 13, 2022
1 parent 18bde30 commit c2517a8
Showing 1 changed file with 7 additions and 19 deletions.
26 changes: 7 additions & 19 deletions python/tvm/topi/x86/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
from tvm.contrib import mkldnn

from .utils import get_fp32_len
from .injective import schedule_injective_from_existing
from .. import tag
from .. import generic, tag
from ..utils import traverse_inline, get_const_tuple


Expand Down Expand Up @@ -306,17 +305,6 @@ def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, tr
return C


def schedule_matmul_blas_common(outs):
"""Default matmul schedule for BLAS library"""
s = te.create_schedule([x.op for x in outs])
te.schedule.AutoInlineInjective(s)

for out in outs:
if "dense" not in out.op.tag and "matmul" not in out.op.tag:
schedule_injective_from_existing(s, out)
return s


@autotvm.register_topi_compute("dense_cblas.x86")
def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
"""Compute dense using cblas. This is an alias of matmul_nt operator."""
Expand All @@ -326,7 +314,7 @@ def dense_cblas(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule("dense_cblas.x86")
def schedule_dense_cblas(_, outs):
"""Create schedule for dense_cblas. This is an alias of matmul_nt operator."""
return schedule_matmul_blas_common(outs)
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("dense_mkl.x86")
Expand All @@ -338,7 +326,7 @@ def dense_mkl(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule("dense_mkl.x86")
def schedule_dense_mkl(_, outs):
"""Create schedule for dense_mkl. This is an alias of matmul_nt operator."""
return schedule_matmul_blas_common(outs)
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("dense_mkldnn.x86")
Expand All @@ -350,7 +338,7 @@ def dense_mkldnn(cfg, data, weight, bias=None, out_dtype=None):
@autotvm.register_topi_schedule("dense_mkldnn.x86")
def schedule_dense_mkldnn(_, outs):
"""Create schedule for dense_mkldnn. This is an alias of matmul_nt operator."""
return schedule_matmul_blas_common(outs)
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("matmul_cblas.x86")
Expand All @@ -366,7 +354,7 @@ def matmul_cblas(
@autotvm.register_topi_schedule("matmul_cblas.x86")
def schedule_matmul_cblas(_, outs):
"""Create schedule for matmul_cblas."""
return schedule_matmul_blas_common(outs)
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("matmul_mkl.x86")
Expand All @@ -382,7 +370,7 @@ def matmul_mkl(
@autotvm.register_topi_schedule("matmul_mkl.x86")
def schedule_matmul_mkl(_, outs):
"""Create schedule for matmul_mkl."""
return schedule_matmul_blas_common(outs)
return generic.schedule_extern(outs)


@autotvm.register_topi_compute("matmul_mkldnn.x86")
Expand All @@ -398,4 +386,4 @@ def matmul_mkldnn(
@autotvm.register_topi_schedule("matmul_mkldnn.x86")
def schedule_matmul_mkldnn(_, outs):
"""Create schedule for matmul_mkldnn."""
return schedule_matmul_blas_common(outs)
return generic.schedule_extern(outs)

0 comments on commit c2517a8

Please sign in to comment.