Skip to content

Commit 0420742

Browse files
committed
lint
1 parent 396cca4 commit 0420742

File tree

7 files changed

+24
-33
lines changed

7 files changed

+24
-33
lines changed

examples/attention_sink/example_gqa_sink_bwd_bhsd.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def get_bwd_configs():
2020

2121

2222
@tilelang.jit(
23-
out_idx=[3, 4],
24-
pass_configs={
23+
out_idx=[3, 4], pass_configs={
2524
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
2625
})
2726
def flashattn_fwd(
@@ -139,8 +138,7 @@ def flash_fwd(
139138

140139

141140
@tilelang.jit(
142-
out_idx=[2],
143-
pass_configs={
141+
out_idx=[2], pass_configs={
144142
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
145143
})
146144
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
@@ -178,8 +176,7 @@ def make_dq_layout(dQ):
178176

179177

180178
@tilelang.jit(
181-
out_idx=[1],
182-
pass_configs={
179+
out_idx=[1], pass_configs={
183180
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
184181
})
185182
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
@@ -202,10 +199,9 @@ def flash_bwd_post(
202199
return flash_bwd_post
203200

204201

205-
@tilelang.jit(
206-
pass_configs={
207-
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
208-
})
202+
@tilelang.jit(pass_configs={
203+
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
204+
})
209205
def flashattn_bwd(batch,
210206
heads,
211207
seq_len,

examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def get_configs():
2323
rep=100,
2424
)
2525
@tilelang.jit(
26-
out_idx=[3],
27-
pass_configs={
26+
out_idx=[3], pass_configs={
2827
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
2928
})
3029
def flashattn(

examples/attention_sink/example_mha_sink_bwd_bhsd.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ def get_bwd_configs():
2020

2121

2222
@tilelang.jit(
23-
out_idx=[3, 4],
24-
pass_configs={
23+
out_idx=[3, 4], pass_configs={
2524
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
2625
})
2726
def flashattn_fwd(
@@ -136,8 +135,7 @@ def flash_fwd(
136135

137136

138137
@tilelang.jit(
139-
out_idx=[2],
140-
pass_configs={
138+
out_idx=[2], pass_configs={
141139
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
142140
})
143141
def flashattn_bwd_preprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
@@ -175,8 +173,7 @@ def make_dq_layout(dQ):
175173

176174

177175
@tilelang.jit(
178-
out_idx=[1],
179-
pass_configs={
176+
out_idx=[1], pass_configs={
180177
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
181178
})
182179
def flashattn_bwd_postprocess(batch, heads, seq_len, dim, dtype: str = "float16"):
@@ -199,10 +196,9 @@ def flash_bwd_post(
199196
return flash_bwd_post
200197

201198

202-
@tilelang.jit(
203-
pass_configs={
204-
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
205-
})
199+
@tilelang.jit(pass_configs={
200+
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
201+
})
206202
def flashattn_bwd(
207203
batch,
208204
heads,

examples/attention_sink/example_mha_sink_fwd_bhsd.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ def get_configs():
1818

1919
@autotune(configs=get_configs(), warmup=500, rep=100)
2020
@tilelang.jit(
21-
out_idx=[3],
22-
pass_configs={
21+
out_idx=[3], pass_configs={
2322
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
2423
})
2524
def flashattn(

examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ def get_configs():
1919

2020
@autotune(configs=get_configs(), warmup=500, rep=100)
2121
@tilelang.jit(
22-
out_idx=[3],
23-
pass_configs={
22+
out_idx=[3], pass_configs={
2423
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
2524
})
2625
def flashattn(

src/target/codegen_cuda.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -948,17 +948,21 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
948948
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
949949
// bfloat162 -> float2
950950
PrintIndent();
951-
stream << sret << " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" << src << ")));\n";
951+
stream << sret
952+
<< " = __bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
953+
<< src << ")));\n";
952954
os << sret;
953955
return;
954956
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {
955957
// bfloat162x2 -> float4
956958
PrintIndent();
957959
stream << "((float2*)(&" << sret << "))[0] = "
958-
<< "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&(" << src << ")));\n";
960+
<< "__bfloat1622float2(*reinterpret_cast<__nv_bfloat162*>(&("
961+
<< src << ")));\n";
959962
PrintIndent();
960963
stream << "((float2*)(&" << sret << "))[1] = "
961-
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&(" << src << "))+1));\n";
964+
<< "__bfloat1622float2(*(reinterpret_cast<__nv_bfloat162*>(&("
965+
<< src << "))+1));\n";
962966
os << sret;
963967
return;
964968
}
@@ -967,8 +971,8 @@ void CodeGenTileLangCUDA::VisitExpr_(const CastNode *op, std::ostream &os) {
967971
if (from_ty.lanes() == 2 && target_ty.lanes() == 2) {
968972
// float2 -> bfloat162
969973
PrintIndent();
970-
stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret << ")) = __float22bfloat162_rn(*(float2*)(&("
971-
<< src << ")));\n";
974+
stream << "*reinterpret_cast<__nv_bfloat162*>(&(" << sret
975+
<< ")) = __float22bfloat162_rn(*(float2*)(&(" << src << ")));\n";
972976
os << sret;
973977
return;
974978
} else if (from_ty.lanes() == 4 && target_ty.lanes() == 4) {

testing/python/language/test_tilelang_language_vectorized_cast.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ def run_vectorized_cast(src_dtype_str: str, dst_dtype_str: str, check_str: str,
3232
src_dtype_str: The source data type string.
3333
dst_dtype_str: The destination data type string.
3434
check_str: Used to ensure vectorized cast is used.
35-
M: The size of the tensor.
3635
lanes: The number of lanes of the source and destination data types.
3736
"""
3837

@@ -80,4 +79,3 @@ def test_vectorized_cast():
8079

8180
if __name__ == "__main__":
8281
tilelang.testing.main()
83-

0 commit comments

Comments (0)