5 changes: 1 addition & 4 deletions benchmark/matmul/benchmark_matmul.py
@@ -53,10 +53,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch

if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
topk = 10

carve_template = MatmulTemplate(
5 changes: 1 addition & 4 deletions benchmark/matmul/benchmark_matmul_intrinsic.py
@@ -187,10 +187,7 @@ def get_configs(args, kwargs):
from tilelang.carver.roller.rasterization import NoRasterization
import torch

if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
topk = 10

carve_template = MatmulTemplate(
2 changes: 1 addition & 1 deletion docs/deeplearning_operators/gemv.md
@@ -252,7 +252,7 @@ def splitk_gemv_vectorized(
return main
```

With vectorized read, now the kernel finishs in **~0.0084 ms**, which is getting close to cuBLAS performance.
With vectorized reads, the kernel now finishes in **~0.0084 ms**, which is getting close to cuBLAS performance.


## `tvm_thread_allreduce` Instead of `atomicAdd`
6 changes: 2 additions & 4 deletions examples/analyze/example_conv_analyze.py
@@ -4,6 +4,7 @@
from tilelang.carver.arch import CDNA
from tilelang.layout import make_swizzled_layout
import torch

N = 64
C = 256
H = 512
@@ -95,10 +96,7 @@ def conv(

def main():
my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
if torch.version.hip is not None:
cuda_device=CDNA("hip")
else:
cuda_device = CUDA("cuda")
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)
print(result)
print(f"Analyzed FLOPs: {result.total_flops}")
5 changes: 1 addition & 4 deletions examples/analyze/example_gemm_analyze.py
@@ -49,10 +49,7 @@ def matmul(
def main():
my_func = kernel(128, 128, 32, 3, 128, True)

if torch.version.hip is not None:
cuda_device=CDNA("hip")
else:
cuda_device = CUDA("cuda")
cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
result = Analyzer.analysis(my_func, cuda_device)

print(f"Analyzed FLOPs: {result.total_flops}")
2 changes: 1 addition & 1 deletion examples/bitnet-1.58b/modeling_bitnet.py
@@ -1373,7 +1373,7 @@ def prepare_inputs_for_generation(self,
cache_length + input_ids.shape[1] > max_cache_length):
attention_mask = attention_mask[:, -max_cache_length:]

position_ids = kwargs.get("position_ids", None)
position_ids = kwargs.get("position_ids")
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
5 changes: 1 addition & 4 deletions examples/gemm/example_gemm_autotune.py
@@ -16,10 +16,7 @@ def ref_program(A, B):

def get_configs(M, N, K, with_roller=False, topk=20):
if with_roller:
if torch.version.hip is not None:
arch=CDNA("hip")
else:
arch = CUDA("cuda")
arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
carve_template = MatmulTemplate(
M=M,
N=N,
2 changes: 1 addition & 1 deletion src/op/gemm_sp.cc
@@ -230,7 +230,7 @@ Stmt GemmSP::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
<< " and " << B.scope();
ICHECK((E.scope() == "shared" || E.scope() == "shared.dyn"))
<< "Only support shared.dyn scope for E as copy from smem to rmem are "
"delegated to cute implemntation, found "
"delegated to cute implementation, found "
<< E.scope();
ss << op_name << "<" << M << ", " << N << ", " << K << ", ";
ss << warp_m << ", " << warp_n << ", ";
2 changes: 1 addition & 1 deletion src/target/codegen_cpp.h
@@ -95,7 +95,7 @@ class CodeGenTileLangCPP : public CodeGenC {
Array<String> function_names_;
/*! \brief whether to emit asserts in the resulting C code */
bool emit_asserts_;
/*! \brief whether to emit forwared function declarations in the resulting C
/*! \brief whether to emit forward function declarations in the resulting C
* code */
bool emit_fwd_func_decl_;

6 changes: 3 additions & 3 deletions src/target/codegen_webgpu.cc
@@ -252,9 +252,9 @@ CodeGenTileLangWebGPU::AddFunction(const PrimFunc &f, bool skip_readonly_decl) {
os_param_access << "]";
func_info.launch_param_tags.push_back(os_param_access.str());

ICHECK(!info.has_block_index_z)
<< "blockIdx.z is not supported in WebGPU to accomodate large blockIdx.x";
// anotate workgroup
ICHECK(!info.has_block_index_z) << "blockIdx.z is not supported in WebGPU to "
"accommodate large blockIdx.x";
// annotate workgroup
this->stream << "@compute @workgroup_size(" << info.workgroup_size[0] << ", "
<< info.workgroup_size[1] << ", " << info.workgroup_size[2]
<< ")\n";
26 changes: 13 additions & 13 deletions src/tl_templates/cpp/half.hpp
@@ -284,7 +284,7 @@
#endif

#ifndef HALF_ENABLE_F16C_INTRINSICS
/// Enable F16C intruction set intrinsics.
/// Enable F16C instruction set intrinsics.
/// Defining this to 1 enables the use of [F16C compiler
/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between
/// half-precision and single-precision values which may result in improved
@@ -1674,7 +1674,7 @@ template <typename T> T half2float(unsigned int value) {
/// \tparam R rounding mode to use
/// \tparam E `true` for round to even, `false` for round away from zero
/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never
/// raise it \tparam T type to convert to (buitlin integer type with at least 16
/// raise it \tparam T type to convert to (builtin integer type with at least 16
/// bits precision, excluding any implicit sign bits) \param value
/// half-precision value to convert \return rounded integer value \exception
/// FE_INVALID if value is not representable in type \a T \exception FE_INEXACT
@@ -1778,7 +1778,7 @@ inline uint32 divide64(uint32 x, uint32 y, int &s) {
/// \tparam R `true` to compute signed remainder, `false` for positive remainder
/// \param x first operand as positive finite half-precision value
/// \param y second operand as positive finite half-precision value
/// \param quo adress to store quotient at, `nullptr` if \a Q `false`
/// \param quo address to store quotient at, `nullptr` if \a Q `false`
/// \return modulus of \a x / \a y
template <bool Q, bool R>
unsigned int mod(unsigned int x, unsigned int y, int *quo = NULL) {
@@ -2435,7 +2435,7 @@ template <typename, typename, std::float_round_style> struct half_caster;
/// Half-precision floating-point type.
/// This class implements an IEEE-conformant half-precision floating-point type
/// with the usual arithmetic operators and conversions. It is implicitly
/// convertible to single-precision floating-point, which makes artihmetic
/// convertible to single-precision floating-point, which makes arithmetic
/// expressions and functions with mixed-type operands to be of the most precise
/// operand type.
///
@@ -2445,9 +2445,9 @@ template <typename, typename, std::float_round_style> struct half_caster;
/// which means it can be standard-conformantly copied using raw binary copies.
/// But in this context some more words about the actual size of the type.
/// Although the half is representing an IEEE 16-bit type, it does not
/// neccessarily have to be of exactly 16-bits size. But on any reasonable
/// necessarily have to be of exactly 16-bits size. But on any reasonable
/// implementation the actual binary representation of this type will most
/// probably not ivolve any additional "magic" or padding beyond the simple
/// probably not involve any additional "magic" or padding beyond the simple
/// binary representation of the underlying 16-bit IEEE number, even if not
/// strictly guaranteed by the standard. But even then it only has an actual
/// size of 16 bits if your C++ implementation supports an unsigned integer type
@@ -2801,7 +2801,7 @@ template <> class numeric_limits<half_float::half> {
static HALF_CONSTEXPR_CONST bool traps = true;
#else
/// Traps only if [HALF_ERRHANDLING_THROW_...](\ref
/// HALF_ERRHANDLING_THROW_INVALID) is acitvated.
/// HALF_ERRHANDLING_THROW_INVALID) is activated.
static HALF_CONSTEXPR_CONST bool traps = false;
#endif

@@ -5067,7 +5067,7 @@ inline half frexp(half arg, int *exp) {
/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn).
/// \param arg number to modify
/// \param exp power of two to multiply with
/// \return \a arg multplied by 2 raised to \a exp
/// \return \a arg multiplied by 2 raised to \a exp
/// \exception FE_INVALID for signaling NaN
/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
inline half scalbln(half arg, long exp) {
@@ -5096,7 +5096,7 @@ inline half scalbln(half arg, long exp) {
/// **See also:** Documentation for
/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). \param
/// arg number to modify \param exp power of two to multiply with \return \a arg
/// multplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }

@@ -5106,7 +5106,7 @@ inline half scalbn(half arg, int exp) { return scalbln(arg, exp); }
/// **See also:** Documentation for
/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). \param
/// arg number to modify \param exp power of two to multiply with \return \a arg
/// multplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
/// multiplied by 2 raised to \a exp \exception FE_INVALID for signaling NaN
/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
inline half ldexp(half arg, int exp) { return scalbln(arg, exp); }

@@ -5379,7 +5379,7 @@ inline HALF_CONSTEXPR bool islessequal(half x, half y) {
!isnan(x) && !isnan(y);
}

/// Quiet comarison for less or greater.
/// Quiet comparison for less or greater.
/// **See also:** Documentation for
/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater).
/// \param x first operand
@@ -5503,7 +5503,7 @@ inline int feraiseexcept(int excepts) {
///
/// **See also:** Documentation for
/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
/// \param flagp adress to store flag state at
/// \param flagp address to store flag state at
/// \param excepts OR of flags to save
/// \retval 0 for success
inline int fegetexceptflag(int *flagp, int excepts) {
@@ -5520,7 +5520,7 @@ inline int fegetexceptflag(int *flagp, int excepts) {
///
/// **See also:** Documentation for
/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag).
/// \param flagp adress to take flag state from
/// \param flagp address to take flag state from
/// \param excepts OR of flags to restore
/// \retval 0 for success
inline int fesetexceptflag(const int *flagp, int excepts) {
2 changes: 1 addition & 1 deletion src/tl_templates/cuda/common.h
@@ -48,7 +48,7 @@ using int4_t = int4;
} \
} while (0)

// abs function for bfloat_t and half_t since there is no implicit convertion
// abs function for bfloat_t and half_t since there is no implicit conversion
// method
TL_PATCH TL_DEVICE half_t __habs(const half_t x) {
return half_t(__habs(x.to_half()));
2 changes: 1 addition & 1 deletion src/tl_templates/cuda/debug.h
@@ -118,7 +118,7 @@ debug_print_buffer_value<signed char>(const char *msg, const char *buf_name,
threadIdx.z, buf_name, index, var);
}

// Specialization for unsiged char type
// Specialization for unsigned char type
template <>
__device__ void
debug_print_buffer_value<unsigned char>(const char *msg, const char *buf_name,
2 changes: 1 addition & 1 deletion src/transform/atomicadd_vectorize.cc
@@ -1,6 +1,6 @@
/*!
* \file atomicadd_vectorize.cc
* \brief A tool to atomatically vectorize atomic add
* \brief A tool to automatically vectorize atomic add
*/

#include "../layout/layout.h"
8 changes: 4 additions & 4 deletions src/transform/merge_shared_memory_allocations.cc
@@ -303,7 +303,7 @@ class SharedMemLinearAccessPatternFinder final : public StmtExprVisitor {
bool IsAppropriateSharedMemory(const Var &var) {
return is_dynamic_ ? IsDynamicSharedMemory(var) : IsStaticSharedMemory(var);
}
// Whether do dyanmic analysis.
// Whether do dynamic analysis.
bool is_dynamic_{true};
// Whether do aggressive merge.
bool enable_aggressive_merge_{false};
@@ -435,7 +435,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
const AllocateNode *alloc = shmem_allocs_[buffer];
auto alignment = align[i];
// Modern nvidia architecture performs hardware swizzling (hopper
// wgmma/tma for exmaple) requires dynamic shared memory address to
// wgmma/tma for example) requires dynamic shared memory address to
// be aligned to 1024 bytes For other devices, we align to 16 bytes
if (shmem_alignment_map_.find(buffer) !=
shmem_alignment_map_.end()) {
@@ -943,7 +943,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
*/
StorageEntry *NewAlloc(const AllocateNode *op, size_t const_nbits) {
ICHECK(op != nullptr);
// Re-use not successful, allocate a new buffer.
// Reuse not successful, allocate a new buffer.
StorageEntry *entry = arena_.make<StorageEntry>();
entry->allocs.push_back({op->buffer_var.get()});
entry->const_nbits = const_nbits;
@@ -1046,7 +1046,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
sym_free_list_.push_back(e);
}
}
// Wheather enable dyanmic analysis.
// Whether enable dynamic analysis.
bool is_dynamic_{true};

// Whether enable verbose logging.
18 changes: 9 additions & 9 deletions src/transform/storage_rewrite.cc
@@ -140,17 +140,17 @@ class AllocateCollector : public StmtExprVisitor {
//
class LinearAccessPatternFinder final : public StmtExprVisitor {
public:
/*! \brief record the touch hist of statment. */
/*! \brief record the touch hist of statement. */
struct StmtEntry {
// The statment
// The statement
const Object *stmt;
// The index in the linear_seq_ to point to end of the nested scope.
// This is only set to non-zero if stmt is a nested scope.
// if offset > 0, means this is the begin, the end entry is current_index +
// offset if offset < 0, means this is the end, the begin entry is
// current_index + offset
int64_t scope_pair_offset{0};
// The buffer variables this statment touched.
// The buffer variables this statement touched.
std::vector<const VarNode *> touched;
};
// The scope of each allocation
@@ -675,7 +675,7 @@ class StoragePlanRewriter : public StmtExprMutator {
scope.tag != ".workspace" && scope.tag != ".vtcm";
}

// Alllocate entry of node.
// Allocate entry of node.
// Event entry in liveness analysis
struct EventEntry {
// variables we generate
@@ -785,10 +785,10 @@ class StoragePlanRewriter : public StmtExprMutator {
for (const AllocateNode *op : e->allocs) {
ICHECK_EQ(op->extents.size(), 1)
<< "Buffer var " << op->buffer_var->name_hint
<< " was identified as a re-usable allocation, but has "
<< " was identified as a reusable allocation, but has "
<< op->extents.size() << " physical dimensions. "
<< "Currently, only flat 1-d memory spaces should be "
"identified as re-usable "
"identified as reusable "
"allocations.";
PrimExpr sz = op->extents[0];
auto nbits = op->dtype.bits() * op->dtype.lanes();
@@ -905,7 +905,7 @@ class StoragePlanRewriter : public StmtExprMutator {
void PlanNewScope(const Object *op) {
if (thread_scope_ != nullptr) {
ICHECK(thread_scope_ == op);
// erase all memory atatched to this scope.
// erase all memory attached to this scope.
for (auto it = const_free_map_.begin(); it != const_free_map_.end();) {
if (it->second->attach_scope_ == op) {
it = const_free_map_.erase(it);
@@ -1023,7 +1023,7 @@ class StoragePlanRewriter : public StmtExprMutator {
StorageEntry *NewAlloc(const AllocateNode *op, const Object *attach_scope,
const StorageScope &scope, size_t const_nbits) {
ICHECK(op != nullptr);
// Re-use not successful, allocate a new buffer.
// Reuse not successful, allocate a new buffer.
auto entry = std::make_unique<StorageEntry>();
entry->attach_scope_ = attach_scope;
entry->scope = scope;
@@ -1050,7 +1050,7 @@
// have its own allocation with size determined at runtime.
bool is_known_size = (const_nbits != 0);

// Currently, only flat memory spaces can be re-used. Packing
// Currently, only flat memory spaces can be reused. Packing
// into N-d space (e.g. 2-d texture memory on GPUs) will require
// more in-depth algorithms.
bool is_flat_memory_space = (num_physical_dimensions == 1);
2 changes: 1 addition & 1 deletion src/transform/thread_storage_sync.cc
@@ -189,7 +189,7 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
}
}
}
// return the exposed entries, remove unecessary ones.
// return the exposed entries, remove unnecessary ones.
int sync_count = 0;
// head are before first sync, tail are after last sync
std::vector<AccessEntry> head, tail;
6 changes: 3 additions & 3 deletions src/transform/vectorize_loop.cc
@@ -527,7 +527,7 @@ class TLVectorizer : public StmtMutator,
// A single var can be binded in multiple lets
// but they have to bind to the same value.
// This is used to allow cases when we reuse a single let
// expression to cosntruct a nested expr.
// expression to construct a nested expr.
// (let x = 1 in x + 1) * (let x = 1 in x + 1)
auto it = let_binding_.find(op->var);
if (it != let_binding_.end()) {
@@ -683,7 +683,7 @@ class TLVectorizer : public StmtMutator,
return StmtMutator::VisitStmt_(op);
}

// scalarize the statment
// scalarize the statement
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
stmt = Substitute(stmt, {{var_, idx}});
@@ -701,7 +701,7 @@
PrimExpr var_lanes_;
// ramp representing the var.
PrimExpr ramp_;
// flag to mark requirment of scalarization.
// flag to mark requirement of scalarization.
bool need_scalarize_{false};
// Let binding
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> let_binding_;