
Commit 54ff485

Merge remote-tracking branch 'upstream/main' into update-maint-release
2 parents: 703e63c + 7d96189

24 files changed: +244 −85 lines

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
@@ -56,10 +56,10 @@ jobs:
        run: |
          "${{ steps.setup-pylowest.outputs.python-path }}" -m compileall -q -f tilelang

-      - name: Setup Python 3.12
+      - name: Setup Python 3.9
        uses: actions/setup-python@v6
        with:
-          python-version: "3.12"
+          python-version: "3.9"
          update-environment: true
          cache: pip
          cache-dependency-path: |
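
The compileall gate above runs the package through the lowest supported interpreter so that syntax it cannot parse fails CI early. It only catches parse-time problems, not runtime ones such as evaluating `int | None` annotations on 3.9, which is why the annotation changes below are made explicitly. A minimal local equivalent of that step, assuming it is run under a Python 3.9 interpreter from the repository root (stdlib only):

    # Byte-compile every module in the tilelang package, forcing recompilation;
    # compile_dir() returns a falsy value if any file fails to compile.
    import compileall
    import sys

    ok = compileall.compile_dir("tilelang", quiet=1, force=True)
    sys.exit(0 if ok else 1)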

examples/attention_sink/benchmark_gqa_sink_fwd.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional


 @triton.jit
@@ -94,7 +95,7 @@ def triton_kernel(
     Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
     bs, n_heads, seq_q, head_dim = Q.shape
     _, n_heads_kv, seq_kv, _ = K.shape
     BLOCK_M = 64
@@ -130,7 +131,7 @@ def main(
     seq_kv: int = 256,
     dim: int = 128,
     groups: int = 8,
-    window_size: int | None = None,
+    window_size: Optional[int] = None,
     dtype: str = "float16",
     tune: bool = False,
 ):
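
The `Optional[int]` spellings in this and the following files are what actually require attention when dropping to Python 3.9: the PEP 604 form `int | None` is valid syntax on 3.9 but is evaluated eagerly when the `def` executes, so it raises a TypeError at import time unless `from __future__ import annotations` is in effect. A minimal sketch of the failure mode (the `windowed` function is illustrative only, not part of the benchmark):

    from typing import Optional

    # Fine on Python 3.9: the union is built via typing, not the | operator.
    def windowed(window_size: Optional[int] = None) -> int:
        return -1 if window_size is None else window_size

    # On Python 3.9 this definition raises at import time:
    #   TypeError: unsupported operand type(s) for |: 'type' and 'NoneType'
    # def windowed(window_size: int | None = None) -> int: ...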

examples/attention_sink/benchmark_mha_sink_fwd.py

Lines changed: 3 additions & 2 deletions
@@ -5,6 +5,7 @@
 import triton.language as tl
 from triton.tools.tensor_descriptor import TensorDescriptor
 from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs
+from typing import Optional


 @triton.jit
@@ -93,7 +94,7 @@ def triton_kernel(
     Out.store([off_z, off_h, start_m * BLOCK_M, 0], acc)


-def triton_program(Q, K, V, Sinks, window_size: int | None = None) -> torch.Tensor:
+def triton_program(Q, K, V, Sinks, window_size: Optional[int] = None) -> torch.Tensor:
     bs, n_heads, seq_q, head_dim = Q.shape
     seq_kv = K.shape[2]
     BLOCK_M = 64
@@ -125,7 +126,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 1 addition & 1 deletion
@@ -444,7 +444,7 @@ def main(BATCH: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 64,
          groups: int = 2,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 1 deletion
@@ -272,7 +272,7 @@ def main(
     seq_kv: int = 256,
     dim: int = 128,
     groups: int = 8,
-    window_size: int | None = None,
+    window_size: Optional[int] = None,
     dtype: str = "float16",
     tune: bool = False,
 ):

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 1 addition & 1 deletion
@@ -440,7 +440,7 @@ def main(BATCH: int = 1,
          H: int = 1,
          N_CTX: int = 512,
          D_HEAD: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16"):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
     if window_size is not None:

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def main(batch: int = 1,
          seq_q: int = 256,
          seq_kv: int = 256,
          dim: int = 128,
-         window_size: int | None = None,
+         window_size: Optional[int] = None,
          dtype: str = "float16",
          tune: bool = False):
     torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]

pyproject.toml

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,7 @@
 name = "tilelang"
 description = "A tile level programming language to generate high performance code."
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 authors = [{ name = "TileLang Contributors" }, { name = "Tile-AI" }]
 maintainers = [{ name = "Lei Wang", email = "leiwang1999@outlook.com" }]
 license = "MIT"
@@ -14,7 +14,6 @@ classifiers = [
     "Operating System :: MacOS",
     "Programming Language :: C++",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
@@ -128,7 +127,7 @@ skip = [
 ]

 [tool.ruff]
-target-version = "py38"
+target-version = "py39"
 line-length = 100
 output-format = "full"
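
With `requires-python` raised to ">=3.9", installers will refuse the package on 3.8, and ruff's `target-version = "py39"` keeps lint suggestions within 3.9-compatible syntax. A small sketch of checking an interpreter against the declared constraint, assuming the `packaging` library is available (the specifier string is copied from the pyproject above):

    import sys
    from packaging.specifiers import SpecifierSet

    requires_python = SpecifierSet(">=3.9")  # from [project] requires-python
    current = ".".join(str(v) for v in sys.version_info[:3])
    print(current in requires_python)  # False on 3.8.x, True on 3.9 and later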

src/op/fill.cc

Lines changed: 41 additions & 9 deletions
@@ -17,6 +17,7 @@
 #include "../transform/loop_partition.h"
 #include "../transform/loop_vectorize.h"
 #include "builtin.h"
+#include "region.h"

 namespace tvm {
 namespace tl {
@@ -62,7 +63,30 @@ using namespace tir;
 Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
   ObjectPtr<FillNode> node = tvm::ffi::make_object<FillNode>();

-  if (args[0]->IsInstance<BufferLoadNode>()) {
+  // Case 1: Region descriptor call (tl.region)
+  if (const auto *call = args[0].as<CallNode>()) {
+    if (call->op.same_as(RegionOp::Get())) {
+      auto region = RegionOp(call->args, vmap);
+      node->dst = region->GetBuffer();
+      node->region = region->GetRanges();
+    } else if (call->op.same_as(builtin::tvm_access_ptr())) {
+      node->dst = vmap[GetVarFromAccessPtr(args[0])];
+      for (int i = 0; i < node->dst->shape.size(); i++) {
+        node->region.push_back(Range(0, node->dst->shape[i]));
+      }
+    } else {
+      ICHECK(false) << "Unsupported call op in tl.fill: "
+                    << Downcast<Op>(call->op)->name;
+    }
+
+    // Case 2: Explicit BufferRegion (legacy path)
+  } else if (args[0]->IsInstance<BufferRegionNode>()) {
+    auto region = Downcast<BufferRegion>(args[0]);
+    node->dst = region->buffer;
+    node->region = region->region;
+
+    // Case 3: Vector/scalar region expressed via BufferLoad indices
+  } else if (args[0]->IsInstance<BufferLoadNode>()) {
     auto buffer_load = Downcast<BufferLoad>(args[0]);
     for (const auto &index : buffer_load->indices) {
       if (const auto *ramp = index.as<RampNode>()) {
@@ -77,6 +101,7 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
       }
     }
     node->dst = buffer_load->buffer;
+    // Case 4: Access pointer, fill the full buffer
   } else {
     node->dst = vmap[GetVarFromAccessPtr(args[0])];
     for (int i = 0; i < node->dst->shape.size(); i++) {
@@ -95,14 +120,19 @@ Fill::Fill(Array<PrimExpr> args, BufferMap vmap) {
                  << " != " << node->dst->shape.size();
   for (int i = 0; i < node->region.size(); i++) {
     // bound check if region is static
-    if (node->region[i]->min.as<IntImm>()) {
-      int64_t min = Downcast<IntImm>(node->region[i]->min)->value;
+    if (const auto *min_imm = node->region[i]->min.as<IntImmNode>()) {
+      int64_t min = min_imm->value;
       ICHECK_GE(min, 0) << "region[" << i << "] = " << min << " < 0";
     }
-    if (node->region[i]->extent.as<IntImm>()) {
-      int64_t extent = Downcast<IntImm>(node->region[i]->extent)->value;
-      ICHECK_LE(extent, Downcast<IntImm>(node->dst->shape[i])->value)
-          << "region[" << i << "] = " << extent << " > " << node->dst->shape[i];
+    if (const auto *extent_imm = node->region[i]->extent.as<IntImmNode>()) {
+      // Only perform the upper-bound check when the destination shape
+      // extent is also statically known. If the shape is symbolic (e.g., Var),
+      // skip this static check to avoid invalid downcasts.
+      if (const auto *shape_imm = node->dst->shape[i].as<IntImmNode>()) {
+        ICHECK_LE(extent_imm->value, shape_imm->value)
+            << "region[" << i << "] = " << extent_imm->value << " > "
+            << node->dst->shape[i];
+      }
     }
   }
   data_ = std::move(node);
@@ -140,7 +170,8 @@ For FillNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   for (int i = 0; i < ndim; i++) {
     Var var = Var(std::string{char('i' + i)}, region[i]->extent->dtype);
     loop_vars.push_back({region[i], var, IterVarType::kDataPar});
-    dst_indices.push_back(var);
+    // Offset the loop induction variable by region min to honor sliced regions
+    dst_indices.push_back(region[i]->min + var);
   }
   Stmt body = BufferStore(dst, value, dst_indices);
   for (int i = ndim - 1; i >= 0; i--) {
@@ -202,6 +233,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     return vectorized_thread_loop;
   } else {
     LOG(FATAL) << "Unsupported scope " << dst.scope();
+    return Stmt();
   }
 }

@@ -229,4 +261,4 @@ TIR_REGISTER_TL_OP(Fill, fill)
 TVM_FFI_STATIC_INIT_BLOCK() { FillNode::RegisterReflection(); }

 } // namespace tl
-} // namespace tvm
+} // namespace tvm
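
Functionally, the fill.cc changes do two things: the Fill constructor now accepts sliced destinations (a tl.region call, an explicit BufferRegion, or the existing BufferLoad/access-pointer forms), and MakeSIMTLoop offsets each destination index by the region's min, so filling a sub-region no longer writes starting at offset zero. A plain NumPy reference for the intended loop semantics (the names and region layout are illustrative, not the TIR lowering itself):

    import numpy as np

    def fill_region_reference(dst, region, value):
        """Write `value` over [min_k, min_k + extent_k) on each axis, mirroring
        dst_indices = region[i]->min + loop_var in FillNode::MakeSIMTLoop."""
        index = tuple(slice(lo, lo + extent) for lo, extent in region)
        dst[index] = value

    buf = np.zeros((8, 8), dtype=np.float16)
    fill_region_reference(buf, region=[(2, 4), (0, 8)], value=1.0)  # rows 2..5 only
    assert buf[:2].sum() == 0 and buf[2:6].all() and buf[6:].sum() == 0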
