[SME] Introduce scalable fp32 dense schedule (#16921)
This commit adds a new scalable fp32 dense schedule that calls SME intrinsics according to the SME RFC: apache/tvm-rfcs#107.

Currently the schedule does not make use of predication, meaning the output from the matmul compute must be copied in a subsequent compute stage. This extra copy stage will be removed once support for predication is added.
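As a rough illustration of the two-stage pattern described above, the following is a minimal TE sketch (not the schedule added by this commit) of a matmul stage followed by a separate copy stage; shapes and names are invented for illustration.

```python
# Minimal TE sketch of the pattern described in the commit message: the matmul
# result is produced by one compute stage (the stage that would be tensorized
# with SME intrinsics) and then copied out by a trailing compute stage. Once
# predication is supported, the extra copy stage can be dropped.
# Shapes and names are illustrative only.
from tvm import te

M, K, N = 32, 64, 16
A = te.placeholder((M, K), dtype="float32", name="A")
B = te.placeholder((K, N), dtype="float32", name="B")
k = te.reduce_axis((0, K), name="k")

# Stage 1: the matmul compute.
C_gemm = te.compute(
    (M, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C_gemm"
)

# Stage 2: the subsequent copy stage referred to above.
C = te.compute((M, N), lambda i, j: C_gemm[i, j], name="C")
```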
lhutton1 authored May 15, 2024
1 parent cfe1711 commit b49468d
Showing 24 changed files with 1,127 additions and 122 deletions.
10 changes: 10 additions & 0 deletions python/tvm/micro/testing/aot_test_utils.py
@@ -65,6 +65,16 @@
},
)

AOT_APROFILE_AEM_RUNNER = AOTTestRunner(
    makefile="aprofile_aem",
    includes=[],
    pass_config={
        "tir.usmp.enable": False,
        # AOT test infra generates 'fake' tensor inputs which fail asserts
        "tir.disable_assert": True,
    },
)


def parametrize_aot_options(test):
    """Parametrize over valid option combinations"""
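As a hedged illustration of how the new runner might be consumed, the sketch below drives a small fp32 dense workload through TVM's AOT test helpers (AOTTestModel, generate_ref_data and compile_and_run from tvm.testing.aot); the workload and interface settings are assumptions for illustration, not taken from this commit.

```python
# Hedged usage sketch (illustrative): run a small fp32 dense workload on the
# AProfile AEM FVP via the AOT test infrastructure and the runner added above.
import numpy as np
import tvm
from tvm import relay
from tvm.micro.testing.aot_test_utils import AOT_APROFILE_AEM_RUNNER
from tvm.testing.aot import AOTTestModel, compile_and_run, generate_ref_data

data = relay.var("data", shape=(4, 8), dtype="float32")
weight = relay.const(np.random.uniform(size=(16, 8)).astype("float32"))
mod = tvm.IRModule.from_expr(relay.Function([data], relay.nn.dense(data, weight)))

inputs = {"data": np.random.uniform(size=(4, 8)).astype("float32")}
outputs = generate_ref_data(mod, inputs)

compile_and_run(
    AOTTestModel(module=mod, inputs=inputs, outputs=outputs),
    runner=AOT_APROFILE_AEM_RUNNER,
    interface_api="packed",
    use_unpacked_api=False,
)
```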
69 changes: 68 additions & 1 deletion python/tvm/relay/op/strategy/arm_cpu.py
@@ -21,7 +21,9 @@
# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
import re

import tvm
from tvm import relay, topi, tir
from tvm.tir.schedule.analysis import has_block

from ....auto_scheduler import is_auto_scheduler_enabled
from ....meta_schedule import is_meta_schedule_enabled
@@ -639,7 +641,7 @@ def schedule_bitserial_dense_arm_cpu(attrs, inputs, out_type, target):
def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
    """dense arm cpu strategy"""
    strategy = _op.OpStrategy()
    data, _ = inputs
    data, weight = inputs

    if target.features.has_dsp and data.dtype in ["int8", "int16"]:
        strategy.add_implementation(
@@ -680,6 +682,23 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
            plevel=11,
        )

    if (
        target.features.has_sme
        and data.dtype in ["float32"]
        and weight.dtype in ["float32"]
        and out_type.dtype in ["float32"]
        # The schedule uses tensorization which does not work when the
        # reduction axis has unit iters. See
        # https://github.com/apache/tvm/issues/16566
        and data.shape[1] > 1
    ):
        strategy.add_implementation(
            wrap_compute_dense(topi.arm_cpu.compute_matmul_sme),
            lambda: None,
            name="matmul.arm_cpu.sme",
            plevel=12,
        )

    # Fallback to x86 schedules as there is currently no arm_cpu schedule for dense
    strategy.add_implementation(
        wrap_compute_dense(topi.x86.dense_nopack),
@@ -697,6 +716,40 @@ def schedule_dense_arm_cpu(attrs, inputs, out_type, target):
    return strategy


@matmul_strategy.register("arm_cpu")
def matmul_strategy_arm_cpu(attrs, inputs, out_type, target):
    """matmul arm cpu strategy"""
    strategy = _op.OpStrategy()
    data, weight = inputs

    if (
        target.features.has_sme
        and data.dtype in ["float32"]
        and weight.dtype in ["float32"]
        and out_type.dtype in ["float32"]
        and not (attrs.transpose_a or attrs.transpose_b)
        and len(data.shape) == 2
        # The schedule uses tensorization which does not work when the
        # reduction axis has unit iters. See
        # https://github.com/apache/tvm/issues/16566
        and data.shape[1] > 1
    ):
        # Ideally we should check that weight is a Relay constant, but strategy functions
        # don't have access to the data needed to check this.
        strategy.add_implementation(
            wrap_compute_matmul(topi.arm_cpu.compute_matmul_sme),
            lambda: None,
            name="matmul.arm_cpu.sme",
        )
        return strategy

    logger.warning("matmul is not optimized for arm cpu.")
    strategy.add_implementation(
        wrap_compute_matmul(topi.nn.matmul), naive_schedule, name="matmul.generic"
    )
    return strategy


@conv1d_strategy.register("arm_cpu")
def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
    """conv1d strategy"""
@@ -737,3 +790,17 @@ def conv1d_strategy_arm_cpu(attrs, inputs, out_type, target):
f"Unsupported kernel layout {kernel_layout} for conv1d {layout} for arm cpu."
)
return strategy


def arm_cpu_tir_strategy(sch: tir.Schedule) -> bool:
    """
    Strategy for arm_cpu STIR schedules.
    """
    current_target = tvm.target.Target.current()

    if current_target.features.has_sme and has_block(sch, "matmul_sme_gemm"):
        topi.arm_cpu.matmul.tir_schedule_matmul_sme(sch)
        return True

    # Fallback to TE schedule for operators we have not written a special TIR schedule for
    return False
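As a hedged illustration of when the new SME path is selected by the strategy above, the sketch below builds a Relay matmul that satisfies the guards (fp32 operands, 2-D data, no transposes, reduction axis larger than one); the AArch64 target string with +sme is an assumption for illustration, not taken from this commit.

```python
# Hedged sketch: a workload intended to satisfy the SME guards in
# matmul_strategy_arm_cpu above. The target string is an assumption.
import tvm
from tvm import relay

data = relay.var("data", shape=(32, 64), dtype="float32")      # reduction axis K = 64 > 1
weight = relay.var("weight", shape=(64, 16), dtype="float32")  # transpose_a = transpose_b = False
func = relay.Function([data, weight], relay.nn.matmul(data, weight))
mod = tvm.IRModule.from_expr(func)

# Assumed SME-capable AArch64 target; with it, the strategy above should pick
# the "matmul.arm_cpu.sme" implementation.
target = tvm.target.Target("llvm -mtriple=aarch64-none-elf -mattr=+v9.2a,+sme")
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target)
```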
17 changes: 17 additions & 0 deletions python/tvm/testing/utils.py
@@ -1023,6 +1023,19 @@ def _corstone300_compile_time_check():
    parent_features="cmsisnn",
)


def _aprofile_aem_fvp_compile_time_check():
    if shutil.which("FVP_Base_RevC-2xAEMvA") is None:
        return "AProfile AEM is not available"
    return True


requires_aprofile_aem_fvp = Feature(
    "aprofile-aem-fvp",
    "AProfile AEM FVP",
    compile_time_check=_aprofile_aem_fvp_compile_time_check,
)

# Mark a test as requiring Vitis AI to run
requires_vitis_ai = Feature("vitis_ai", "Vitis AI", cmake_flag="USE_VITIS_AI")

@@ -1205,6 +1218,10 @@ def decorator(*args):
    return decorator


def skip_if_no_reference_system(func):
    return skip_if_32bit(reason="Reference system unavailable in i386 container")(func)


def requires_package(*packages):
    """Mark a test as requiring python packages to run.
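A hedged sketch of gating a test on the new feature marker added above, assuming it is applied as a decorator in the same way as the existing requires_* markers in tvm.testing.

```python
# Hedged usage sketch: skip the test unless the AEM FVP binary is on PATH,
# using the feature marker added in this commit (decorator usage is assumed
# to mirror the other tvm.testing.requires_* markers).
import tvm.testing


@tvm.testing.requires_aprofile_aem_fvp
def test_fp32_dense_on_aem_fvp():
    ...  # body elided; would build and run an fp32 dense workload on the FVP
```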
1 change: 0 additions & 1 deletion python/tvm/tir/tensor_intrin/__init__.py
@@ -16,4 +16,3 @@
# under the License.
# pylint: disable=unused-import
"""Intrinsics for tensorization."""
from . import arm_cpu, cuda, rocm, x86, hexagon
