@@ -41,23 +41,13 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
4141 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
4242 uint32_t smem_int_mbar = smem_ptr_to_uint (&smem_mbar);
4343 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
44- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
45- asm volatile (" cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
46- " complete_tx::bytes"
47- " [%0], [%1, {%3}], [%2];"
48- :
49- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
50- " r" (crd0)
51- : " memory" );
52- } else {
53- asm volatile (" cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
54- " complete_tx::bytes.L2::cache_hint"
55- " [%0], [%1, {%3}], [%2], %4;"
56- :
57- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
58- " r" (crd0), " l" (cache_hint)
59- : " memory" );
60- }
44+ asm volatile (" cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::"
45+ " complete_tx::bytes.L2::cache_hint"
46+ " [%0], [%1, {%3}], [%2], %4;"
47+ :
48+ : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
49+ " r" (crd0), " l" (cache_hint)
50+ : " memory" );
6151}
6252
6353template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -67,23 +57,13 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
6757 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
6858 uint32_t smem_int_mbar = smem_ptr_to_uint (&smem_mbar);
6959 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
70- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
71- asm volatile (" cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
72- " complete_tx::bytes"
73- " [%0], [%1, {%3, %4}], [%2];"
74- :
75- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
76- " r" (crd0), " r" (crd1)
77- : " memory" );
78- } else {
79- asm volatile (" cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
80- " complete_tx::bytes.L2::cache_hint"
81- " [%0], [%1, {%3, %4}], [%2], %5;"
82- :
83- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
84- " r" (crd0), " r" (crd1), " l" (cache_hint)
85- : " memory" );
86- }
60+ asm volatile (" cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::"
61+ " complete_tx::bytes.L2::cache_hint"
62+ " [%0], [%1, {%3, %4}], [%2], %5;"
63+ :
64+ : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
65+ " r" (crd0), " r" (crd1), " l" (cache_hint)
66+ : " memory" );
8767}
8868
8969template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -93,23 +73,13 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
9373 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
9474 uint32_t smem_int_mbar = smem_ptr_to_uint (&smem_mbar);
9575 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
96- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
97- asm volatile (" cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
98- " complete_tx::bytes"
99- " [%0], [%1, {%3, %4, %5}], [%2];"
100- :
101- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
102- " r" (crd0), " r" (crd1), " r" (crd2)
103- : " memory" );
104- } else {
105- asm volatile (" cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
106- " complete_tx::bytes.L2::cache_hint"
107- " [%0], [%1, {%3, %4, %5}], [%2], %6;"
108- :
109- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
110- " r" (crd0), " r" (crd1), " r" (crd2), " l" (cache_hint)
111- : " memory" );
112- }
76+ asm volatile (" cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::"
77+ " complete_tx::bytes.L2::cache_hint"
78+ " [%0], [%1, {%3, %4, %5}], [%2], %6;"
79+ :
80+ : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
81+ " r" (crd0), " r" (crd1), " r" (crd2), " l" (cache_hint)
82+ : " memory" );
11383}
11484template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
11585TL_DEVICE void tma_load (const CUtensorMap &descriptor, uint64_t &smem_mbar,
@@ -119,23 +89,13 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
11989 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
12090 uint32_t smem_int_mbar = smem_ptr_to_uint (&smem_mbar);
12191 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
122- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
123- asm volatile (" cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
124- " complete_tx::bytes"
125- " [%0], [%1, {%3, %4, %5, %6}], [%2];"
126- :
127- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
128- " r" (crd0), " r" (crd1), " r" (crd2), " r" (crd3)
129- : " memory" );
130- } else {
131- asm volatile (" cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
132- " complete_tx::bytes.L2::cache_hint"
133- " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
134- :
135- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
136- " r" (crd0), " r" (crd1), " r" (crd2), " r" (crd3), " l" (cache_hint)
137- : " memory" );
138- }
92+ asm volatile (" cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::"
93+ " complete_tx::bytes.L2::cache_hint"
94+ " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
95+ :
96+ : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
97+ " r" (crd0), " r" (crd1), " r" (crd2), " r" (crd3), " l" (cache_hint)
98+ : " memory" );
13999}
140100
141101template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -146,24 +106,14 @@ TL_DEVICE void tma_load(const CUtensorMap &descriptor, uint64_t &smem_mbar,
146106 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
147107 uint32_t smem_int_mbar = smem_ptr_to_uint (&smem_mbar);
148108 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
149- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
150- asm volatile (" cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
151- " complete_tx::bytes"
152- " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
153- :
154- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
155- " r" (crd0), " r" (crd1), " r" (crd2), " r" (crd3), " r" (crd4)
156- : " memory" );
157- } else {
158- asm volatile (" cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
159- " complete_tx::bytes.L2::cache_hint"
160- " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
161- :
162- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
163- " r" (crd0), " r" (crd1), " r" (crd2), " r" (crd3), " r" (crd4),
164- " l" (cache_hint)
165- : " memory" );
166- }
109+ asm volatile (" cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::"
110+ " complete_tx::bytes.L2::cache_hint"
111+ " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
112+ :
113+ : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
114+ " r" (crd0), " r" (crd1), " r" (crd2), " r" (crd3), " r" (crd4),
115+ " l" (cache_hint)
116+ : " memory" );
167117}
168118
169119template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -176,49 +126,27 @@ TL_DEVICE void tma_load_im2col(const CUtensorMap &descriptor,
176126 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
177127 uint32_t smem_int_mbar = smem_ptr_to_uint (&smem_mbar);
178128 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
179- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
180- asm volatile (
181- " cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
182- " :complete_tx::bytes"
183- " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};"
184- :
185- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
186- " r" (coord_c), " r" (coord_w), " r" (coord_h), " r" (coord_n), " h" (offset_w),
187- " h" (offset_h)
188- : " memory" );
189- } else {
190- asm volatile (
191- " cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
192- " :complete_tx::bytes.L2::cache_hint"
193- " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;"
194- :
195- : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
196- " r" (coord_c), " r" (coord_w), " r" (coord_h), " r" (coord_n), " h" (offset_w),
197- " h" (offset_h), " l" (cache_hint)
198- : " memory" );
199- }
129+ asm volatile (" cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier:"
130+ " :complete_tx::bytes.L2::cache_hint"
131+ " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;"
132+ :
133+ : " r" (smem_int_ptr), " l" (gmem_int_desc), " r" (smem_int_mbar),
134+ " r" (coord_c), " r" (coord_w), " r" (coord_h), " r" (coord_n),
135+ " h" (offset_w), " h" (offset_h), " l" (cache_hint)
136+ : " memory" );
200137}
201138
202139template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
203140TL_DEVICE void tma_store (const CUtensorMap &descriptor,
204141 void const *const smem_ptr, int32_t const &crd0) {
205142 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
206143 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
207-
208- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
209- asm volatile (" cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, "
210- " {%2}], [%1];"
211- :
212- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0)
213- : " memory" );
214- } else {
215- asm volatile (" cp.async.bulk.tensor.1d.global.shared::cta.bulk_group "
216- " ::cache_hint [%0, {%2}], [%1], %3;"
217- :
218- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0),
219- " l" (cache_hint)
220- : " memory" );
221- }
144+ asm volatile (" cp.async.bulk.tensor.1d.global.shared::cta.bulk_group "
145+ " .L2::cache_hint [%0, {%2}], [%1], %3;"
146+ :
147+ : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0),
148+ " l" (cache_hint)
149+ : " memory" );
222150}
223151
224152template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -227,21 +155,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
227155 int32_t const &crd1) {
228156 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
229157 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
230-
231- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
232- asm volatile (" cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, "
233- " {%2, %3}], [%1];"
234- :
235- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1)
236- : " memory" );
237- } else {
238- asm volatile (" cp.async.bulk.tensor.2d.global.shared::cta.bulk_group "
239- " ::cache_hint [%0, {%2, %3}], [%1], %4;"
240- :
241- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
242- " l" (cache_hint)
243- : " memory" );
244- }
158+ asm volatile (" cp.async.bulk.tensor.2d.global.shared::cta.bulk_group "
159+ " .L2::cache_hint [%0, {%2, %3}], [%1], %4;"
160+ :
161+ : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
162+ " l" (cache_hint)
163+ : " memory" );
245164}
246165
247166template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -250,22 +169,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
250169 int32_t const &crd1, int32_t const &crd2) {
251170 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
252171 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
253-
254- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
255- asm volatile (" cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, "
256- " {%2, %3, %4}], [%1];"
257- :
258- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
259- " r" (crd2)
260- : " memory" );
261- } else {
262- asm volatile (" cp.async.bulk.tensor.3d.global.shared::cta.bulk_group "
263- " ::cache_hint [%0, {%2, %3, %4}], [%1], %5;"
264- :
265- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
266- " r" (crd2), " l" (cache_hint)
267- : " memory" );
268- }
172+ asm volatile (" cp.async.bulk.tensor.3d.global.shared::cta.bulk_group "
173+ " .L2::cache_hint [%0, {%2, %3, %4}], [%1], %5;"
174+ :
175+ : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
176+ " r" (crd2), " l" (cache_hint)
177+ : " memory" );
269178}
270179
271180template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -275,22 +184,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
275184 int32_t const &crd3) {
276185 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
277186 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
278-
279- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
280- asm volatile (" cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, "
281- " {%2, %3, %4, %5}], [%1];"
282- :
283- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
284- " r" (crd2), " r" (crd3)
285- : " memory" );
286- } else {
287- asm volatile (" cp.async.bulk.tensor.4d.global.shared::cta.bulk_group "
288- " ::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;"
289- :
290- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
291- " r" (crd2), " r" (crd3), " l" (cache_hint)
292- : " memory" );
293- }
187+ asm volatile (" cp.async.bulk.tensor.4d.global.shared::cta.bulk_group "
188+ " .L2::cache_hint [%0, {%2, %3, %4, %5}], [%1], %6;"
189+ :
190+ : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
191+ " r" (crd2), " r" (crd3), " l" (cache_hint)
192+ : " memory" );
294193}
295194
296195template <CacheHintSm90 cache_hint = CacheHintSm90::EVICT_NORMAL>
@@ -300,22 +199,12 @@ TL_DEVICE void tma_store(const CUtensorMap &descriptor,
300199 int32_t const &crd3, int32_t const &crd4) {
301200 uint64_t gmem_int_desc = reinterpret_cast <uint64_t >(&descriptor);
302201 uint32_t smem_int_ptr = smem_ptr_to_uint (smem_ptr);
303-
304- if constexpr (cache_hint == CacheHintSm90::EVICT_NORMAL) {
305- asm volatile (" cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, "
306- " {%2, %3, %4, %5, %6}], [%1];"
307- :
308- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
309- " r" (crd2), " r" (crd3), " r" (crd4)
310- : " memory" );
311- } else {
312- asm volatile (" cp.async.bulk.tensor.5d.global.shared::cta.bulk_group "
313- " ::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
314- :
315- : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
316- " r" (crd2), " r" (crd3), " r" (crd4), " l" (cache_hint)
317- : " memory" );
318- }
202+ asm volatile (" cp.async.bulk.tensor.5d.global.shared::cta.bulk_group "
203+ " .L2::cache_hint [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
204+ :
205+ : " l" (gmem_int_desc), " r" (smem_int_ptr), " r" (crd0), " r" (crd1),
206+ " r" (crd2), " r" (crd3), " r" (crd4), " l" (cache_hint)
207+ : " memory" );
319208}
320209
321210TL_DEVICE void prefetch_tma_descriptor (const CUtensorMap &descriptor) {
0 commit comments