Skip to content

Commit a16f0cf

Browse files
authored
[Enhancement] Improve buffer conflict detection in thread storage synchronization (#658)
* [Enhancement] Improve buffer conflict detection in thread storage synchronization
  - Added a new boolean variable `range_is_overlap` to accurately determine whether buffer indices overlap, enhancing the conflict detection logic in `thread_storage_sync.cc`.
  - Updated the return logic to reflect the overlap status, ensuring correct conflict resolution based on buffer index comparisons.
  - Removed an unnecessary comment in `OptimizeForTarget` to streamline the code and improve clarity.
* example fix
* enhancement
* improve ci
1 parent c8edb95 commit a16f0cf

File tree

5 files changed

+49
-15
lines changed

5 files changed

+49
-15
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,11 +94,11 @@ jobs:
9494
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
9595
cd examples
9696
unset PYTHONPATH
97-
python -m pytest -n 4 **/test*.py
97+
python -m pytest -n 8 **/test*.py
9898
9999
- name: Run tests
100100
run: |
101101
source "${{ runner.tool_cache }}/${{ env.VENV_DIR }}/bin/activate"
102102
cd testing/python
103103
unset PYTHONPATH
104-
python -m pytest -n 4
104+
python -m pytest -n 8

examples/bitnet-1.58b/vllm_workspace/conftest.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,7 @@
2020
from vllm import LLM, SamplingParams
2121
from vllm.assets.image import ImageAsset
2222
from vllm.config import TokenizerPoolConfig
23-
from vllm.distributed import (
24-
destroy_distributed_environment,
25-
destroy_model_parallel,
26-
)
23+
from vllm.distributed import (destroy_distributed_environment, destroy_model_parallel)
2724
from vllm.inputs import TextPrompt
2825
from vllm.logger import init_logger
2926
from vllm.sequence import SampleLogprobs

examples/warp_specialize/example_warp_specialize_flashmla.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,9 @@ def flash_attn(
4949
scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
5050
scores_max_1 = T.alloc_fragment([block_H], accum_dtype)
5151
scores_max = T.alloc_shared([block_H], accum_dtype)
52-
# TODO(lei): this is a workaround for the bug of replicate if stmt.
53-
# have to be optimized in future with index aware sync thread pass injection.
54-
# scores_max_prev_0 and scores_max_prev_1 should be allocated in fragment.
55-
scores_max_prev_0 = T.alloc_shared([block_H], accum_dtype)
56-
scores_max_prev_1 = T.alloc_shared([block_H], accum_dtype)
52+
53+
scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
54+
scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)
5755

5856
scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
5957
scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
@@ -395,7 +393,7 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
395393
return out
396394

397395

398-
def main(batch=132, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
396+
def main(batch=1, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
399397
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
400398
pv_flops = 2 * batch * heads * kv_ctx * dim
401399
total_flops = qk_flops + pv_flops
@@ -414,7 +412,7 @@ def main(batch=132, heads=128, kv_heads=1, kv_ctx=8192, dim=512, pe_dim=64):
414412

415413
if __name__ == "__main__":
416414
parser = argparse.ArgumentParser()
417-
parser.add_argument('--batch', type=int, default=132, help='batch size')
415+
parser.add_argument('--batch', type=int, default=1, help='batch size')
418416
parser.add_argument('--heads', type=int, default=128, help='q heads number')
419417
parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
420418
parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')

src/transform/thread_storage_sync.cc

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
258258
// TODO(tqchen) more standard set based testing.
259259
bool has_same_index = true;
260260
bool range_is_equal = true;
261+
bool range_is_overlap = true;
262+
261263
for (const auto &kv : prev.thread_range) {
262264
if (!StructuralEqual()(kv.second, curr.thread_range[kv.first])) {
263265
range_is_equal = false;
@@ -275,6 +277,40 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
275277
const auto &curr_indice = curr.buffer_indices[i];
276278
if (!ExprDeepEqual()(prev_indice, curr_indice)) {
277279
has_same_index = false;
280+
281+
// If both are const, we can check if they are disjoint
282+
// by checking if the bounds are disjoint
283+
// [1024, 2048], [2048, 3072] are disjoint
284+
// [1024, 2048], [1024, 1024] are not disjoint
285+
auto prev_bound = analyzer_.const_int_bound(prev_indice);
286+
auto curr_bound = analyzer_.const_int_bound(curr_indice);
287+
if (prev_bound.defined() && curr_bound.defined()) {
288+
if (prev_bound->min_value > curr_bound->max_value ||
289+
curr_bound->min_value > prev_bound->max_value) {
290+
range_is_overlap = false;
291+
break;
292+
}
293+
}
294+
295+
// if we can prove prev_indice < curr_indice or prev_indice >
296+
// curr_indice, then they are not overlap
297+
auto prev_dtype = prev_indice.dtype();
298+
auto curr_dtype = curr_indice.dtype();
299+
if (prev_dtype.lanes() != curr_dtype.lanes()) {
300+
// can not support different lanes binary op like <, >, <=, >=
301+
// skip otherwise it will lead to error
302+
continue;
303+
}
304+
bool provably_disjoint =
305+
analyzer_.CanProve(prev_indice < curr_indice,
306+
arith::ProofStrength::kSymbolicBound) ||
307+
analyzer_.CanProve(prev_indice > curr_indice,
308+
arith::ProofStrength::kSymbolicBound);
309+
310+
if (provably_disjoint) {
311+
range_is_overlap = false;
312+
break;
313+
}
278314
}
279315

280316
if (!(has_same_index)) {
@@ -291,9 +327,13 @@ class TileLangThreadSyncPlanner : public TileLangStorageAccessVisitor {
291327
if (prev.double_buffer_write && curr.type == kRead && !loop_carry) {
292328
return false;
293329
}
330+
294331
// If nothing else allows sharing the same buffer, then they are
295332
// in conflict.
296-
return true;
333+
// if range_is_overlap is true, then they are in conflict, we should return
334+
// true. if range_is_overlap is false, then they are not in conflict, we
335+
// should return false.
336+
return range_is_overlap;
297337
}
298338

299339
void VisitStmt_(const AttrStmtNode *op) final {

tilelang/engine/phase.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule:
175175
mod)
176176
mod = tilelang.transform.ThreadSync("shared")(mod)
177177
mod = tilelang.transform.ThreadSync("shared.dyn")(mod)
178-
179178
# Inject PTX async copy must behind the thread sync pass
180179
# as ptx async copy won't be recognized as a valid buffer load
181180
mod = tilelang.transform.InjectPTXAsyncCopy()(mod)

0 commit comments

Comments
 (0)