implement swizzling template with mma
LeiWang1999 committed Jan 23, 2024
1 parent ce36b97 commit 1beb6ae
Showing 25 changed files with 861 additions and 152 deletions.
13 changes: 6 additions & 7 deletions include/tvm/script/ir_builder/tir/ir.h
@@ -455,13 +455,12 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false, \
int64_t min_value = 0) { \
DataType dtype = DType; \
return expr.defined() ? tvm::cast(dtype, expr.value()) \
: (is_size_var ? tvm::tir::SizeVar("", dtype, min_value) \
: tvm::tir::Var("", dtype)); \
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false) { \
DataType dtype = DType; \
return expr.defined() \
? tvm::cast(dtype, expr.value()) \
: (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
11 changes: 2 additions & 9 deletions include/tvm/tir/var.h
@@ -142,12 +142,6 @@ class Var : public PrimExpr {
*/
class SizeVarNode : public VarNode {
public:
int64_t min_value;
void VisitAttrs(tvm::AttrVisitor* v) {
VarNode::VisitAttrs(v);
v->Visit("min_value", &min_value);
}

static constexpr const char* _type_key = "tir.SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};
@@ -163,15 +157,14 @@ class SizeVar : public Var {
* \param span The location of this object in the source code.
*/
TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
int64_t min_value = 0, Span span = Span());
Span span = Span());
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
* \param span The location of this object in the source code.
*/
TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, int64_t min_value = 0,
Span span = Span());
TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span());
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
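Taken together, the ir.h and var.h hunks above remove the min_value field from SizeVarNode along with the matching constructor and macro parameters. A minimal sketch of the surviving API, written against the Python binding for illustration and assuming a TVM build that includes this commit:

# hedged sketch: SizeVar no longer carries a min_value after this change
import tvm

n = tvm.tir.SizeVar("n", "int32")   # non-negative symbolic extent
buf_elems = n * 4                    # composes into a PrimExpr as before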
1 change: 1 addition & 0 deletions python/tvm/dlight/base/roller/policy/default.py
@@ -714,6 +714,7 @@ def _score(node, thread): # small is better
break
# Plan vectorize
codegen_dict.vectorize = self._plan_vectorize(node, td, block_size)
codegen_dict.arch = self.arch
return codegen_dict

def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int):
5 changes: 3 additions & 2 deletions python/tvm/dlight/base/roller/policy/tensorcore.py
@@ -51,9 +51,9 @@ def _legalize_info(self):
self.use_async_copy = use_async_copy
else:
if self.arch.compute_capability == "sm_80":
self.use_async_copy = 2
else:
self.use_async_copy = 1
else:
self.use_async_copy = 0

def _compute_tc_strides(
self, node: PrimFuncNode, tile: List[int], rstep: Dict[str, int] = {}
@@ -299,6 +299,7 @@ def _score(node, thread): # small is better

codegen_dict.complete_config(node)
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.arch = self.arch
return codegen_dict

def plan_rasterization(self, td: TileDict):
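Both the default and tensor-core policies now record the target architecture on the generated config (codegen_dict.arch = self.arch). A small consumer-side sketch of why that matters, assuming a config shaped like the one these policies emit; the actual dispatch is the utils.py hunk that follows:

# hedged sketch: downstream code branches on the arch attached by the policy
def pick_tensorize_rule(config):
    # sm_80 and newer (e.g. A100) take the mma.sync path added in this commit
    if config.arch.sm_version >= 80:
        return "MatmulTensorizationMMA"
    return "MatmulTensorizationWMMA"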
16 changes: 12 additions & 4 deletions python/tvm/dlight/base/utils.py
@@ -63,7 +63,12 @@ def _apply_config(
if not reduction_blocks:
return dl.gpu.ElementWise().apply_config(func, config)
elif config.use_tc:
return dl.gpu.MatmulTensorization().apply_config(func, config)
if config.arch.sm_version >= 80:
# For A100(sm_80) or more advanced gpu, use MMA tensorization.
return dl.gpu.MatmulTensorizationMMA().apply_config(func, config)
else:
# For other GPUs, use WMMA tensorization.
return dl.gpu.MatmulTensorizationWMMA().apply_config(func, config)
else:
_reduction_rules = []

@@ -93,12 +98,15 @@ def _apply_and_build(
sch = _apply_config(func, config)
if sch is None:
return config, sch, None

# TODO(@lei): is tvm.build thread safe?
try:
with tvm.transform.PassContext(config={"tir.use_async_copy": True}):
with tvm.transform.PassContext(
config={"tir.use_async_copy": True, "tir.merge_static_smem": True}
):
mod = tvm.build(sch.mod["main"], target=arch.target)
except:
except Exception as e_msg:
print(e_msg)
mod = None
return config, sch, mod

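The second utils.py hunk builds the scheduled module under a PassContext that enables asynchronous copies and merged static shared memory, and it now prints the exception instead of silently swallowing it. A standalone sketch of that build step, assuming a scheduled sch and an arch with a target attribute as in the diff:

# hedged sketch of the guarded build introduced above
import tvm

def build_scheduled(sch, arch):
    try:
        with tvm.transform.PassContext(
            config={"tir.use_async_copy": True, "tir.merge_static_smem": True}
        ):
            return tvm.build(sch.mod["main"], target=arch.target)
    except Exception as err:  # keep going so other candidate configs can still be tried
        print(err)
        return None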
7 changes: 2 additions & 5 deletions python/tvm/dlight/gpu/matmul.py
@@ -31,6 +31,7 @@
auto_inline_producers,
get_in_out_dtypes,
get_index_map,
normalize_to_matmul,
get_reduction_blocks,
)
from .matmul_mma import MatmulTensorizationMMA
@@ -259,12 +260,8 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring
# analyzed by matmul expr.
assert len(config.block) == 2, "Matmul Only support 2D block"

if config.use_tc:
tensorize_sch = MatmulMMATensorization().apply_config(func, config)
if tensorize_sch is not None:
return tensorize_sch

main_block = reduction_blocks[0]

block_stmt = sch.get(main_block)

# cuda core prefer b is [k, j] layout without swizzling.