Skip to content

Commit

Permalink
Merge branch 'unity-dev/2023-11-13-matmul-new' of https://github.com/…
Browse files Browse the repository at this point in the history
…Ubospica/tvm-develop into dev/fast_dlight
  • Loading branch information
LeiWang1999 committed Jan 22, 2024
2 parents 03915cf + b47b4fc commit ce36b97
Show file tree
Hide file tree
Showing 32 changed files with 4,042 additions and 2,077 deletions.
13 changes: 7 additions & 6 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,13 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false) { \
DataType dtype = DType; \
return expr.defined() \
? tvm::cast(dtype, expr.value()) \
: (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false, \
int64_t min_value = 0) { \
DataType dtype = DType; \
return expr.defined() ? tvm::cast(dtype, expr.value()) \
: (is_size_var ? tvm::tir::SizeVar("", dtype, min_value) \
: tvm::tir::Var("", dtype)); \
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
Expand Down
11 changes: 9 additions & 2 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ class Var : public PrimExpr {
*/
class SizeVarNode : public VarNode {
public:
int64_t min_value;
void VisitAttrs(tvm::AttrVisitor* v) {
VarNode::VisitAttrs(v);
v->Visit("min_value", &min_value);
}

static constexpr const char* _type_key = "tir.SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};
Expand All @@ -157,14 +163,15 @@ class SizeVar : public Var {
* \param span The location of this object in the source code.
*/
TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
Span span = Span());
int64_t min_value = 0, Span span = Span());
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
* \param span The location of this object in the source code.
*/
TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span());
TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, int64_t min_value = 0,
Span span = Span());
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,17 @@ def canonical_simplify(self, expr):
"""
return self._canonical_simplify(expr)

def int_set(self, expr, dom_map):
def int_set(self, expr, dom_map=None):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.
Parameters
----------
expr : PrimExpr
The expression.
dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
dom_map : Optional[Dict[Var, tvm.arith.IntSet]]
The domain for variables to be relaxed. If None, use the domain map defined by bound
variables.
Returns
-------
Expand Down
13 changes: 8 additions & 5 deletions python/tvm/dlight/gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@
GPU-generic schedule rules.
For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead
"""
from .gemv import GEMV
from .fallback import Fallback
from .matmul import Matmul, MatmulWMMATensorization, MatmulMMATensorization
from .gemv import GEMV
from .general_reduction import GeneralReduction
from .matmul import (
Matmul,
MatmulTensorizationMMA,
MatmulTensorizationWMMA,
MatmulTensorizationLegacy,
)
from .reduction import Reduction
from .transpose import Transpose
from .general_reduction import GeneralReduction
from .element_wise import ElementWise
from .rmsnorm import RMSNorm
Loading

0 comments on commit ce36b97

Please sign in to comment.