@@ -107,7 +107,6 @@ def test_example_chunk_o_compilation():
107107
108108
109109def test_example_chunk_o_bwd_compilation ():
110- tilelang .disable_cache ()
111110 from example_chunk_o_bwd import tilelang_chunk_o_bwd_dqkwg , prepare_input
112111 Q , K , V , h , G , dO , dh , dv , W = prepare_input (B , S , H , DK , DV , chunk_size ,
113112 getattr (torch , input_dtype ),
@@ -118,13 +117,6 @@ def test_example_chunk_o_bwd_compilation():
118117 kernel = tilelang_chunk_o_bwd_dqkwg (B , S , H , DK , DV , input_dtype , output_dtype , accum_dtype ,
119118 gate_dtype , state_dtype , chunk_size , 1.0 , use_g , True ,
120119 block_DK , block_DV , threads , num_stages )
121- # print(kernel.get_kernel_source())
122- kernel = tilelang_chunk_o_bwd_dqkwg (B , S , H , DK , DV , input_dtype , output_dtype , accum_dtype ,
123- gate_dtype , state_dtype , chunk_size , 1.0 , use_g , True ,
124- block_DK , block_DV , threads , num_stages )
125- kernel = tilelang_chunk_o_bwd_dqkwg (B , S , H , DK , DV , input_dtype , output_dtype , accum_dtype ,
126- gate_dtype , state_dtype , chunk_size , 1.0 , use_g , True ,
127- block_DK , block_DV , threads , num_stages )
128120
129121 dq_tilelang , dk_tilelang , dw_tilelang , dg_tilelang = kernel (Q , K , V , h , G , dO , dh , dv ,
130122 W ) # noqa: F841
@@ -197,5 +189,4 @@ def test_example_chunk_delta_bwd_compilation():
197189
198190
199191if __name__ == "__main__" :
200- # tilelang.testing.main()
201- test_example_chunk_o_bwd_compilation ()
192+ tilelang .testing .main ()
0 commit comments