-
Notifications
You must be signed in to change notification settings - Fork 44
/
test_core.py
6277 lines (5239 loc) · 244 KB
/
test_core.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# flake8: noqa: F821,F841
import contextlib
import itertools
import re
from typing import Optional
import math
import textwrap
import pathlib
import numpy as np
import pytest
import torch
import os
import inspect
from numpy.random import RandomState
import triton
import triton.language as tl
from triton.language.extra import libdevice
from triton._internal_testing import (
integral_dtypes,
int_dtypes,
uint_dtypes,
float_dtypes,
float_dtypes_with_bfloat16,
dtypes,
dtypes_with_bfloat16,
is_cuda,
is_interpreter,
is_hopper,
is_hip,
is_hip_cdna,
is_hip_mi200,
is_hip_mi300,
is_xpu,
get_arch,
torch_float8_dtypes,
torch_dtypes,
numpy_random,
to_triton,
torch_dtype_name,
to_numpy,
)
@contextlib.contextmanager
def promotion_numpy_2_0():
    """Temporarily switch NumPy to "weak" scalar promotion.

    The promotion state that was active on entry is restored when the ``with``
    block exits, including when it exits through an exception.

    The state is read and written through NumPy's private
    ``_get_promotion_state`` / ``_set_promotion_state`` hooks. If the running
    NumPy does not expose both hooks there is no state to switch, so the body
    simply runs under NumPy's built-in promotion rules instead of failing with
    ``AttributeError``.
    """
    get_state = getattr(np, "_get_promotion_state", None)
    set_state = getattr(np, "_set_promotion_state", None)
    if get_state is None or set_state is None:
        # Nothing to toggle on this NumPy build.
        yield
        return

    previous = get_state()
    set_state("weak")
    try:
        yield
    finally:
        # Always put the caller's promotion mode back.
        set_state(previous)
def xpu_has_fp64():
    """Return the ``has_fp64`` entry of the active target's ``arch`` mapping.

    Asserts that the XPU backend is the one in use before querying the driver.
    """
    assert is_xpu()
    return triton.runtime.driver.active.get_current_target().arch['has_fp64']
# TODO: enable multiple cta cluster testing.
# num_ctas_list = [1, 4] if torch.cuda.get_device_capability()[0] == 9 else [1]
# Values the `num_ctas` test parameter is swept over.
num_ctas_list = [1]

# Dialect prefix interpolated into the layout attribute strings in this file.
GPU_DIALECT = "ttg"

# 1 under the interpreter, the warp size reported by the active target on HIP,
# and 32 everywhere else.
THREADS_PER_WARP = (1 if is_interpreter() else
                    (triton.runtime.driver.active.get_current_target().warp_size if is_hip() else 32))
def _bitwidth(dtype: str) -> int:
# ex.: "int64" -> 64
return int(re.search(r'(\d+)$', dtype).group(1))
def _dtype(dtype: str) -> str:
# ex.: "int64" -> "int"
return re.match(r'([a-zA-Z]+)', dtype).group(0)
def patch_kernel(template, to_replace):
    """Build a kernel from ``template`` with placeholder text substituted.

    Every key of ``to_replace`` that occurs in the kernel source is replaced
    by the corresponding value.

    Outside the interpreter a new ``triton.JITFunction`` wrapping
    ``template.fn`` is created and its ``src`` is rewritten. Under the
    interpreter the dedented Python source of ``template.fn`` is rewritten,
    executed, and the object bound to the function's name is returned.
    """
    if not is_interpreter():
        patched = triton.JITFunction(template.fn)
        for placeholder, replacement in to_replace.items():
            patched.src = patched.src.replace(placeholder, replacement)
        return patched

    source = textwrap.dedent(inspect.getsource(template.fn))
    for placeholder, replacement in to_replace.items():
        source = source.replace(placeholder, replacement)
    # Execute against this module's globals; definitions land in `scope`.
    scope = {}
    exec(source, globals(), scope)
    return scope[template.fn.__name__]
def check_cuda_or_hip(device):
    """Mark the running test as xfail unless ``device`` is 'cuda'.

    CUDA and HIP both use pytorch device 'cuda'. Other backends like Intel
    GPU do not.
    """
    if device in ['cuda']:
        return
    pytest.xfail("Only for cuda or HIP")
def check_type_supported(dtype, device):
    '''
    Skip or xfail the running test when dtype is not usable on the current device.

    dtype may be given as a triton dtype, a torch dtype or a dtype name string.
    '''
    if device in ['cuda']:
        # Only the major compute-capability number is needed for the checks.
        major = torch.cuda.get_device_capability()[0]
        if major < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16):
            pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80")
        if major < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}:
            pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90")

    if is_interpreter():
        if dtype in [tl.bfloat16, "bfloat16", torch.bfloat16]:
            pytest.xfail("bfloat16 is not supported in the interpreter")
    elif device in ['xpu']:
        # Only consulted when not running under the interpreter.
        if dtype in [torch.float64, "float64"] and not xpu_has_fp64():
            pytest.xfail("float64 not supported on current xpu hardware")
class MfmaLayout:
    """Parameters of an ``amd_mfma`` layout; ``str()`` renders the attribute text."""

    def __init__(self, version, warps_per_cta, instr_shape, is_transposed):
        # `version` is indexed as [0] (major) and [1] (minor) when rendered.
        self.version = version
        self.warps_per_cta = warps_per_cta
        self.instr_shape = instr_shape
        self.is_transposed = is_transposed

    def __str__(self):
        # The boolean is rendered in lower case ("true"/"false").
        transposed = str(self.is_transposed).lower()
        fields = (f"versionMajor={self.version[0]}, versionMinor={self.version[1]}, "
                  f"warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, "
                  f"isTransposed = {transposed}")
        return f"#{GPU_DIALECT}.amd_mfma<{{{fields}}}>"
class WmmaLayout:
    """Parameters of an ``amd_wmma`` layout; ``str()`` renders the attribute text."""

    def __init__(self, version, warps_per_cta):
        self.version = version
        self.warps_per_cta = warps_per_cta

    def __str__(self):
        fields = f"version = {self.version}, warpsPerCTA = {self.warps_per_cta}"
        return f"#{GPU_DIALECT}.amd_wmma<{{{fields}}}>"
class MmaLayout:
    """Parameters of an ``nvidia_mma`` layout; ``str()`` renders the attribute text."""

    def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape):
        # `version` is indexed as [0] (major) and [1] (minor) when rendered.
        self.version = version
        self.warps_per_cta = warps_per_cta
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order
        self.instr_shape = instr_shape

    def __str__(self):
        pairs = [
            ("versionMajor", self.version[0]),
            ("versionMinor", self.version[1]),
            ("warpsPerCTA", self.warps_per_cta),
            ("CTAsPerCGA", self.ctas_per_cga),
            ("CTASplitNum", self.cta_split_num),
            ("CTAOrder", self.cta_order),
            ("instrShape", self.instr_shape),
        ]
        fields = ", ".join(f"{key}={value}" for key, value in pairs)
        return f"#{GPU_DIALECT}.nvidia_mma<{{{fields}}}>"
class DpasLayout:
    """Parameters of a ``triton_intel_gpu.dpas`` layout; ``str()`` renders the attribute text."""

    def __init__(self, repeatCount, systolic_depth, execution_size, ops_per_chan, threads_per_warp, warps_per_cta,
                 rep_cluster):
        self.repeatCount = repeatCount
        self.systolic_depth = systolic_depth
        self.execution_size = execution_size
        self.ops_per_chan = ops_per_chan
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.rep_cluster = rep_cluster

    def __str__(self):
        # Spacing around '=' differs per field and is reproduced exactly.
        fields = (f"repeatCount={self.repeatCount}, systolicDepth={self.systolic_depth}, "
                  f"executionSize = {self.execution_size}, opsPerChan = {self.ops_per_chan}, "
                  f"threadsPerWarp = {self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, "
                  f"repCluster={self.rep_cluster}")
        return f"#triton_intel_gpu.dpas<{{{fields}}}>"
class DotOperandLayout:
    """Parameters of a ``dot_op`` layout wrapping a parent layout; ``str()`` renders the attribute text."""

    def __init__(self, parent, op_idx, k_width):
        # `parent` is interpolated with its own string form when rendered.
        self.parent = parent
        self.op_idx = op_idx
        self.k_width = k_width

    def __str__(self):
        pairs = [("parent", self.parent), ("opIdx", self.op_idx), ("kWidth", self.k_width)]
        fields = ", ".join(f"{key}={value}" for key, value in pairs)
        return f"#{GPU_DIALECT}.dot_op<{{{fields}}}>"
class BlockedLayout:
    """Parameters of a ``blocked`` layout; ``str()`` renders the attribute text."""

    def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order):
        # Note the attribute is stored under the shortened name `sz_per_thread`.
        self.sz_per_thread = size_per_thread
        self.threads_per_warp = threads_per_warp
        self.warps_per_cta = warps_per_cta
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        pairs = [
            ("sizePerThread", self.sz_per_thread),
            ("threadsPerWarp", self.threads_per_warp),
            ("warpsPerCTA", self.warps_per_cta),
            ("order", self.order),
            ("CTAsPerCGA", self.ctas_per_cga),
            ("CTASplitNum", self.cta_split_num),
            ("CTAOrder", self.cta_order),
        ]
        fields = ", ".join(f"{key}={value}" for key, value in pairs)
        return f"#{GPU_DIALECT}.blocked<{{{fields}}}>"
class SharedLayout:
    """Parameters of a ``shared`` layout; ``str()`` renders the attribute text."""

    def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
        self.vec = vec
        self.per_phase = per_phase
        self.max_phase = max_phase
        self.order = order
        self.ctas_per_cga = ctas_per_cga
        self.cta_split_num = cta_split_num
        self.cta_order = cta_order

    def __str__(self):
        pairs = [
            ("vec", self.vec),
            ("perPhase", self.per_phase),
            ("maxPhase", self.max_phase),
            ("order", self.order),
            ("CTAsPerCGA", self.ctas_per_cga),
            ("CTASplitNum", self.cta_split_num),
            ("CTAOrder", self.cta_order),
        ]
        fields = ", ".join(f"{key}={value}" for key, value in pairs)
        return f"#{GPU_DIALECT}.shared<{{{fields}}}>"
def is_layout_applicable(layout) -> bool:
    """Return whether ``layout`` should be exercised on the active backend.

    Args:
        layout: an instance of one of the layout helper classes in this file.

    Returns:
        True for blocked and shared layouts on every backend. Otherwise:
        on CUDA, True only for an ``MmaLayout`` (or a ``DotOperandLayout``
        whose parent is one), and major version >= 3 additionally requires
        Hopper; on HIP, True only for ``WmmaLayout`` on "gfx11" targets or
        ``MfmaLayout`` on "gfx8"/"gfx9" targets; on any other backend, True.
    """
    # `layout` is an instance, so this must be an isinstance check: testing
    # membership in a list of classes would never match.
    if isinstance(layout, (BlockedLayout, SharedLayout)):
        return True

    if is_cuda():
        mma_layout = layout.parent if isinstance(layout, DotOperandLayout) else layout
        if not isinstance(mma_layout, MmaLayout):
            return False
        if mma_layout.version[0] >= 3 and not is_hopper():
            return False
        return True

    if is_hip():
        target_arch = triton.runtime.driver.active.get_current_target().arch
        if "gfx11" in target_arch:
            # RDNA 3
            return isinstance(layout, WmmaLayout)
        if any(arch in target_arch for arch in ("gfx8", "gfx9")):
            # CDNA 1, 2, 3
            return isinstance(layout, MfmaLayout)
        return False

    return True
def filter_layouts(layouts):
    """Return a list of the entries of ``layouts`` accepted by ``is_layout_applicable``, in order."""
    applicable = []
    for candidate in layouts:
        if is_layout_applicable(candidate):
            applicable.append(candidate)
    return applicable
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x", list(dtypes) + ["bfloat16"])
def test_empty_kernel(dtype_x, device):
    """Launch a kernel with an empty body on a buffer of each dtype."""
    check_type_supported(dtype_x, device)
    SIZE = 128

    @triton.jit
    def kernel(X, SIZE: tl.constexpr):
        pass

    host_values = numpy_random(SIZE, dtype_str=dtype_x)
    device_values = to_triton(host_values, device=device, dst_type=dtype_x)
    kernel[(1, )](device_values, SIZE=SIZE, num_warps=4)
# generic test functions
def _test_unary(dtype_x, expr, numpy_expr=None, device='cuda', num_ctas=1):
    """Check a unary expression evaluated in a Triton kernel against a reference.

    Args:
        dtype_x: dtype name of the input; also used as the output dtype.
        expr: expression text over ``x``; substituted into the kernel source
            and, unless ``numpy_expr`` is given, evaluated in Python as the
            reference.
        numpy_expr: optional alternative expression text for the reference.
        device: device the buffers are placed on.
        num_ctas: forwarded to the kernel launch.
    """
    check_type_supported(dtype_x, device)  # skips/xfails the test if dtype_x is not supported
    SIZE = 128
    # define the kernel / launch-grid

    @triton.jit
    def kernel(Z, X, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        # placeholder replaced with `expr` by patch_kernel below
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': expr})
    # inputs
    x = numpy_random(SIZE, dtype_str=dtype_x)
    if 'log' in expr:
        # keep inputs strictly positive for expressions mentioning 'log'
        x = np.abs(x) + 0.01
    # reference result
    # NOTE: the evaluated text refers to the local name `x`; do not rename it.
    z_ref = eval(expr if numpy_expr is None else numpy_expr)
    # triton result
    x_tri = to_triton(x, device=device, dst_type=dtype_x)
    z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_x)
    kernel[(1, )](Z=z_tri, X=x_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas)
    # compare
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01)
def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]:
"""
Given two dtype strings, returns the numpy dtype Triton thinks binary
operations on the two types should return. Returns None if the return value
matches numpy. This is generally needed because Triton and pytorch return
narrower floating point types than numpy in mixed operations, and because
Triton follows C/C++ semantics around mixed signed/unsigned operations, and
numpy/pytorch do not.
"""
overrides = {
('float16', 'int16'): np.float16,
('float16', 'int32'): np.float16,
('float16', 'int64'): np.float16,
('float16', 'uint16'): np.float16,
('float16', 'uint32'): np.float16,
('float16', 'uint64'): np.float16,
('int8', 'uint8'): np.uint8,
('int8', 'uint16'): np.uint16,
('int8', 'uint32'): np.uint32,
('int8', 'uint64'): np.uint64,
('int16', 'uint16'): np.uint16,
('int16', 'uint32'): np.uint32,
('int16', 'uint64'): np.uint64,
('int32', 'uint32'): np.uint32,
('int32', 'uint64'): np.uint64,
('int64', 'uint64'): np.uint64,
}
key = (a, b) if a < b else (b, a)
return overrides.get(key)
def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1,
                 x_low=None, x_high=None, y_low=None, y_high=None, filter_y=None, test_broadcast=True,
                 test_scalar=True):
    """Check a binary expression evaluated in Triton kernels against a reference.

    The expression is tested element-wise on two arrays, optionally with a
    Python scalar right-hand side, and optionally with either operand loaded
    as a single value.

    Args:
        dtype_x, dtype_y: dtype names of the two operands.
        expr: expression text over ``x`` and ``y``; substituted into the
            kernels and, unless ``numpy_expr`` is given, evaluated in Python
            as the reference.
        numpy_expr: optional alternative expression text for the reference.
        mode_x, mode_y: 'nan' fills the corresponding operand with NaN.
        device: device the buffers are placed on.
        num_ctas: forwarded to the kernel launches.
        x_low, x_high, y_low, y_high: bounds forwarded to ``numpy_random``.
        filter_y: optional predicate; matching entries of ``y`` are set to 1.
        test_broadcast: also run the single-value lhs/rhs kernels.
        test_scalar: also run the scalar-rhs kernel (never when mode_y is 'nan').
    """
    check_type_supported(dtype_x, device)  # skips/xfails the test if dtype_x is not supported
    check_type_supported(dtype_y, device)
    SIZE = 128
    # define the kernel / launch-grid

    @triton.jit
    def kernel(Z, X, Y, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        y = tl.load(Y + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    # x is loaded as a single value from X
    @triton.jit
    def kernel_broadcast_lhs(Z, X, Y, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X)
        y = tl.load(Y + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    # y is loaded as a single value from Y
    @triton.jit
    def kernel_broadcast_rhs(Z, X, Y, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        y = tl.load(Y)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    # y is passed as a constexpr argument
    @triton.jit
    def kernel_scalar_rhs(Z, X, y: tl.constexpr, SIZE: tl.constexpr):
        off = tl.arange(0, SIZE)
        x = tl.load(X + off)
        z = GENERATE_TEST_HERE
        tl.store(Z + off, z)

    replacements = {'GENERATE_TEST_HERE': expr}
    kernel = patch_kernel(kernel, replacements)
    kernel_broadcast_lhs = patch_kernel(kernel_broadcast_lhs, replacements)
    kernel_broadcast_rhs = patch_kernel(kernel_broadcast_rhs, replacements)
    kernel_scalar_rhs = patch_kernel(kernel_scalar_rhs, replacements)

    # inputs
    # x, y and (later) the random scalar are all drawn from this one generator, in this order
    rs = RandomState(17)
    x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs, low=x_low, high=x_high)
    y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs, low=y_low, high=y_high)
    if filter_y:
        y[filter_y(y)] = 1
    if mode_x == 'nan':
        x[:] = float('nan')
    if mode_y == 'nan':
        y[:] = float('nan')

    def do_test(x, y, kernel_fn):
        # NOTE: the evaluated expression text refers to the parameter names `x` and `y`.
        x_is_scalar = isinstance(x, (bool, int, float))
        y_is_scalar = isinstance(y, (bool, int, float))
        scalar_test = x_is_scalar or y_is_scalar

        # For scalars, we follow the NumPy 2.0 (and JAX/PyTorch pretty much) casting rules.
        if scalar_test:
            # We remove any explicit casting
            pattern = r'\.astype\(np\.\w+\)'
            scalar_expr = expr if numpy_expr is None else re.sub(pattern, '', numpy_expr)
            with promotion_numpy_2_0():
                z_ref = eval(scalar_expr)
        else:
            z_ref = eval(expr if numpy_expr is None else numpy_expr)

        # The dtype override is only applied to array-vs-array references.
        dtype_z = _binary_op_dtype_override(dtype_x, dtype_y)
        if not scalar_test and dtype_z is not None:
            z_ref = z_ref.astype(dtype_z)

        # triton result
        x_tri = x if x_is_scalar else to_triton(x, device=device, dst_type=dtype_x)
        y_tri = y if y_is_scalar else to_triton(y, device=device, dst_type=dtype_y)
        if is_xpu() and not xpu_has_fp64() and z_ref.dtype in ["float64"]:
            # Downcast the output type. Assumes similar overflow behavior to reference eval on the device.
            z_ref = z_ref.astype("float32")
        z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device)
        kernel_fn[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas)
        err_msg = f"{expr}, {kernel_fn.__name__}"
        np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=err_msg, atol=7e-3, rtol=0.01)

    def get_scalar(x, dtype, low, high, filter):
        # If dtype is int, don't choose a huge number for the scalar
        # as it'll overflow easily when converted to the other dtype
        if dtype in integral_dtypes:
            # Choose in range [-7, 7] ([0, 7] for uints)
            low_x = 0 if dtype in uint_dtypes else -7
            if low is not None:
                low_x = max(low_x, low)
            high_x = 7
            if high is not None:
                high_x = min(high_x, high)
            scalar = numpy_random((), dtype_str=dtype, rs=rs, low=low_x, high=high_x).item()
            if filter and filter(scalar):
                # https://xkcd.com/221/
                scalar = 4
        else:
            # non-integral dtypes reuse the first element of the array
            scalar = x.flat[0].item()
        return scalar

    do_test(x, y, kernel)
    if mode_y != 'nan' and test_scalar:
        if dtype_x in uint_dtypes:
            # with an unsigned lhs, never pick a negative scalar
            low = 0 if y_low is None else max(y_low, 0)
        else:
            low = y_low
        y_scalar = get_scalar(y, dtype_y, low, y_high, filter_y)
        do_test(x, y_scalar, kernel_scalar_rhs)
    if test_broadcast:
        # 0-d arrays: the kernels load a single value for that operand
        do_test(x[:1].reshape(()), y, kernel_broadcast_lhs)
        do_test(x, y[:1].reshape(()), kernel_broadcast_rhs)
def _min_max_integral_mod_value(dtype_x, dtype_y) -> "tuple[int, int]":
    """
    Limit min/max values for integral types for mod values. Leads to
    overflow/underflow when casting large integral types to floats.

    Args:
        dtype_x: integral dtype name of the left operand; its alphabetic
            prefix selects the signedness of the limiting type.
        dtype_y: dtype name of the right operand; only its bit-width is used.

    Returns:
        A ``(low, high)`` pair of bounds taken from a narrower integer type.
    """
    x_bitwidth = _bitwidth(dtype_x)
    y_bitwidth = _bitwidth(dtype_y)

    # hard cap max value bit-width to 32 if 64 bit-width types
    min_bitwidth = min(x_bitwidth, y_bitwidth, 32)

    # Limit max value bit-width to be one integral type less than the min bit-width
    # For example:
    #   int64, float32 -> int16
    #   uint16, float16 -> uint8
    x_dtype = _dtype(dtype_x)
    max_bitwidth = max(min_bitwidth >> 1, 8)
    dtype_max = x_dtype + str(max_bitwidth)

    max_info = np.iinfo(getattr(np, dtype_max))
    high = max_info.max
    # Still need to limit values here for uints
    if max_bitwidth >= 16 and dtype_max in uint_dtypes:
        high //= 4
    return max_info.min, high
def test_dtype_codegen():
    """The repr of each dtype must be its qualified ``triton.language`` name."""
    for name in dtypes_with_bfloat16:
        qualified = f"triton.language.{name}"
        dtype_obj = eval(qualified)
        assert repr(dtype_obj) == qualified
# ---------------
# test binary ops
# ---------------


@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, dtype_y, op", [  #
    (dtype_x, dtype_y, op)
    for op in ['+', '-', '*', '/', '%']
    for dtype_x in dtypes_with_bfloat16
    for dtype_y in dtypes_with_bfloat16
])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_bin_op(dtype_x, dtype_y, op, num_ctas, device):
    """Check the arithmetic operators for every pair of operand dtypes.

    Builds the Triton expression and a matching numpy reference expression,
    then delegates to ``_test_binary``. Mixed signed/unsigned '/' and '%' are
    expected to raise ``triton.TritonError``.
    """
    expr = f'x {op} y'
    # Produces the reference expression text from two operand texts; '%' maps to np.fmod.
    np_expr_gen = (lambda x, y: f'{x} {op} {y}') if op != '%' else (lambda x, y: f'np.fmod({x}, {y})')

    # Triton promotes 16-bit floating-point / and % to 32-bit because there
    # are no native div or FRem operations on float16. Since we have to
    # convert anyway, we may as well take the accuracy bump.
    def promote_to_fp32(dtype_x, dtype_y):
        return dtype_x in ('float16', 'bfloat16') and dtype_y not in ('float32', 'float64')

    if op in ('/', '%') and (promote_to_fp32(dtype_x, dtype_y) or promote_to_fp32(dtype_y, dtype_x)):
        numpy_expr = np_expr_gen('x.astype(np.float32)', 'y.astype(np.float32)')
    elif (dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y)):
        # unsigned operand at least as wide as the signed one: compute the reference in the unsigned type
        numpy_expr = np_expr_gen(f'x.astype(np.{dtype_x})', f'y.astype(np.{dtype_x})')
    elif (dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x)):
        numpy_expr = np_expr_gen(f'x.astype(np.{dtype_y})', f'y.astype(np.{dtype_y})')
    elif op == '%':
        # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders.
        numpy_expr = np_expr_gen('x', 'y')
    else:
        numpy_expr = None

    if (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or
                              (dtype_x in uint_dtypes and dtype_y in int_dtypes))):
        with pytest.raises(triton.TritonError, match='Cannot use .* because they have different signedness'):
            _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas)
    else:
        # skip when bfloat16, as NumPy's ref performs the computation in float32
        # while Triton performs it in bfloat16
        skip_scalar_test = ((dtype_x == "bfloat16" and "float" in dtype_y)
                            or (op in ('/', '%') and dtype_x in ("float16", "bfloat16")))
        # can't divide by zero
        not_zero = op in ('/', '%') and dtype_x in integral_dtypes and dtype_y in integral_dtypes
        # can't represent -int(max)
        not_minus_one = op in ('*', '/') and dtype_x in int_dtypes and dtype_y in int_dtypes
        if not_zero or not_minus_one:
            # the boolean flags switch each exclusion on or off
            filter_y = lambda y: not_zero * (y == 0) | not_minus_one * (y == -1)
        else:
            filter_y = None

        if op == "%" and dtype_x in integral_dtypes and dtype_y in float_dtypes_with_bfloat16:
            x_low, x_high = _min_max_integral_mod_value(dtype_x, dtype_y)
        else:
            x_low, x_high = None, None

        _test_binary(
            dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas,
            # fails with values where fmod(x, y) is roughly zero, but happens to
            # pass with the random values chosen for non-broadcast tests
            test_broadcast=(op != "%"), x_low=x_low, x_high=x_high, filter_y=filter_y, test_scalar=not skip_scalar_test)
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]])
def test_addptr(dtype, order, device):
    """Copy a buffer through pointer arithmetic written as ptr + offs or offs + ptr."""
    check_type_supported(dtype, device)

    @triton.jit
    def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr):
        offs = tl.arange(0, SIZE)
        if ORDER == 0:
            # pointer on the left of '+'
            tl.store(y + offs, tl.load(x + offs))
        else:
            # pointer on the right of '+'
            tl.store(offs + y, tl.load(offs + x))

    SIZE = 1024
    rng = RandomState(17)
    src = numpy_random(SIZE, dtype_str=dtype, rs=rng)
    dst_init = numpy_random(SIZE, dtype_str=dtype, rs=rng)
    src_tri = to_triton(src, dst_type=dtype, device=device)
    dst_tri = to_triton(dst_init, dst_type=dtype, device=device)
    # After the copy the destination must equal the source.
    expected = src
    kernel[(1, )](src_tri, dst_tri, order, SIZE)
    np.testing.assert_allclose(expected, to_numpy(dst_tri))
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, dtype_y", [  #
    (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes
] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_floordiv(dtype_x, dtype_y, num_ctas, device):
    """Check integer ``//`` against a reference built from ``np.fmod``."""
    # Triton has IEEE, not numpy/torch, semantics for %, and those carry
    # through to //, so we have to use a nonstandard expression to get a
    # reference result for //.
    both_signed = dtype_x in int_dtypes and dtype_y in int_dtypes
    # can't represent -int(max)
    filter_y = (lambda y: y == -1) if both_signed else None
    _test_binary(dtype_x, dtype_y, 'x // y', '((x - np.fmod(x, y)) / y)', filter_y=filter_y, device=device,
                 num_ctas=num_ctas)
def test_unsigned_name_mangling(device):
    """Test that uint32 and int32 are mangled differently by the compiler."""
    SIZE = 128

    # define the kernel / launch-grid
    @triton.jit
    def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
        idx = tl.arange(0, SIZE)
        unsigned_vals = tl.load(X + idx)
        signed_vals = tl.load(Y + idx)
        abs_unsigned = tl.abs(unsigned_vals)  # uint32 -> nop
        abs_negated = tl.abs(-signed_vals)  # int32 -> should have an effect
        tl.store(O1 + idx, abs_unsigned)
        tl.store(O2 + idx, abs_negated)

    unsigned_name, signed_name = 'uint32', 'int32'

    # inputs
    rng = RandomState(17)
    x = numpy_random(SIZE, dtype_str=unsigned_name, rs=rng)
    y = numpy_random(SIZE, dtype_str=signed_name, rs=rng)

    # reference result
    expected = (np.abs(x), np.abs(-y))

    # triton result
    x_tri = to_triton(x, device=device, dst_type=unsigned_name)
    y_tri = to_triton(y, device=device, dst_type=signed_name)
    outputs = tuple(to_triton(np.empty_like(ref), device=device) for ref in expected)
    kernel[(1, )](outputs[0], outputs[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)

    # Bitwise op, so expect exact equality
    for ref, out in zip(expected, outputs):
        assert (ref == to_numpy(out)).all()
# test bitwise ops
# ---------------
# NOTE(review): `dtypes + dtypes_with_bfloat16` presumably repeats every entry
# of `dtypes` in the parametrization -- confirm whether that is intentional.
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, dtype_y, op", [  #
    (dtype_x, dtype_y, op)
    for op in ['&', '|', '^']
    for dtype_x in dtypes + dtypes_with_bfloat16
    for dtype_y in dtypes + dtypes_with_bfloat16
])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device):
    """Check ``&``, ``|`` and ``^``; any floating-point operand must be rejected."""
    expr = f'x {op} y'

    def cast_both_to(dtype):
        # Reference expression with both operands cast to `dtype`.
        return f'x.astype(np.{dtype}) {op} y.astype(np.{dtype})'

    if dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y):
        numpy_expr = cast_both_to(dtype_x)
    elif dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x):
        numpy_expr = cast_both_to(dtype_y)
    else:
        numpy_expr = None

    if 'float' not in dtype_x + dtype_y:
        _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas)
        return

    # The CompilationError must have been caused by a C++ exception with this text.
    with pytest.raises(triton.TritonError, match='invalid operands of type'):
        _test_binary(dtype_x, dtype_y, expr, numpy_expr='np.array([])', device=device, num_ctas=num_ctas)
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, dtype_y, op", [  #
    (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes for dtype_y in uint_dtypes
])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_shift_op(dtype_x, dtype_y, op, num_ctas, device):
    """Check ``<<`` and ``>>`` with an unsigned shift amount."""
    # The reference is computed in a type as wide as the wider operand,
    # keeping the signedness of the left operand.
    bw = max(_bitwidth(dtype_x), _bitwidth(dtype_y))
    sign_prefix = 'int' if dtype_x.startswith('int') else 'uint'
    dtype_z = f'{sign_prefix}{bw}'
    reference_expr = f'x.astype(np.{dtype_z}) {op} y.astype(np.{dtype_z})'
    # Shift amounts are drawn with low=0 and high=bw.
    _test_binary(dtype_x, dtype_y, f'x {op} y', reference_expr, device=device, num_ctas=num_ctas, y_low=0, y_high=bw)
# ---------------
# test compare ops
# ---------------
ops = ['==', '!=', '>', '<', '>=', '<=']


@pytest.mark.interpreter
@pytest.mark.parametrize(
    "dtype_x, dtype_y, op, mode_x, mode_y",
    # real
    [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes]
    # NaNs
    + [('float32', 'float32', op, mode_x, mode_y)
       for op in ops
       for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device):
    """Check the comparison operators, including float32 operands filled with NaN."""

    def cast_both_to(dtype):
        # Reference expression with both operands cast to `dtype`.
        return f'x.astype(np.{dtype}) {op} y.astype(np.{dtype})'

    if dtype_x in uint_dtypes and dtype_y in int_dtypes and _bitwidth(dtype_x) >= _bitwidth(dtype_y):
        numpy_expr = cast_both_to(dtype_x)
    elif dtype_y in uint_dtypes and dtype_x in int_dtypes and _bitwidth(dtype_y) >= _bitwidth(dtype_x):
        numpy_expr = cast_both_to(dtype_y)
    else:
        numpy_expr = None

    _test_binary(dtype_x, dtype_y, f'x {op} y', numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device,
                 num_ctas=num_ctas)
# ---------------
# test broadcast
# ---------------
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype, device):
    """Broadcast a length-N vector against an (M, N) block and compare with numpy."""
    check_type_supported(dtype, device)

    @triton.jit
    def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
        rows = tl.arange(0, M)
        cols = tl.arange(0, N)
        # row-major offsets of the (M, N) block
        block_offsets = N * rows[:, None] + cols[None, :]
        x = tl.load(x_ptr + block_offsets)
        y = tl.load(y_ptr + cols)
        _, y_broadcasted = tl.broadcast(x, y)
        tl.store(y_broadcasted_ptr + block_offsets, y_broadcasted)

    M, N = 32, 64

    rng = RandomState(17)
    matrix = numpy_random((M, N), dtype_str=dtype, rs=rng)
    vector = numpy_random(N, dtype_str=dtype, rs=rng)
    _, expected = np.broadcast_arrays(matrix, vector)

    matrix_tri = to_triton(matrix, device=device, dst_type=dtype)
    vector_tri = to_triton(vector, device=device, dst_type=dtype)
    out_tri = to_triton(np.empty((M, N), dtype=expected.dtype), device=device, dst_type=dtype)

    broadcast_kernel[(1, )](matrix_tri, vector_tri, out_tri, M=M, N=N)
    assert (expected == to_numpy(out_tri)).all()
# ----------
# test slice
# ----------
@pytest.mark.interpreter
def test_slice(device):
    """Indexing with None must insert unit dimensions; checked with static asserts."""

    @triton.jit
    def slice_kernel(XBLOCK: tl.constexpr):
        vec = tl.arange(0, XBLOCK)
        tl.static_assert(vec.shape == [XBLOCK])

        row = vec[None, :]
        tl.static_assert(row.shape == [1, XBLOCK])

        cube = vec[None, :, None]
        tl.static_assert(cube.shape == [1, XBLOCK, 1])

        # zero-dimensional tensor
        zero_d = tl.full([], 1, tl.int32)
        tl.static_assert(zero_d.shape == [])

        one_d = zero_d[None]
        tl.static_assert(one_d.shape == [1])

        two_d = zero_d[None, None]
        tl.static_assert(two_d.shape == [1, 1])

    slice_kernel[(1, )](XBLOCK=32)
# ------------------
# test invalid slice
# ------------------
@pytest.mark.interpreter
def test_invalid_slice(device):
    """Slicing a kernel argument with a range must raise TritonError."""
    buffer = torch.empty(128, device=device)

    @triton.jit
    def _kernel(dst):
        dst[10:]

    expected_error = pytest.raises(triton.TritonError, match='unsupported tensor index')
    with expected_error:
        _kernel[(1, )](dst=buffer)
# ----------------
# test expand_dims
# ----------------
@pytest.mark.interpreter
def test_expand_dims(device):
    """``tl.expand_dims`` with single, negative and tuple axes; checked with static asserts."""

    @triton.jit
    def expand_dims_kernel(dummy, N: tl.constexpr):
        base = tl.arange(0, N)

        # single axis, positive and negative
        expanded = tl.expand_dims(base, 0)
        tl.static_assert(expanded.shape == [1, N])

        expanded = tl.expand_dims(base, 1)
        tl.static_assert(expanded.shape == [N, 1])

        expanded = tl.expand_dims(base, -1)
        tl.static_assert(expanded.shape == [N, 1])

        expanded = tl.expand_dims(base, -2)
        tl.static_assert(expanded.shape == [1, N])

        # tuples of axes
        expanded = tl.expand_dims(base, (0, -1))
        tl.static_assert(expanded.shape == [1, N, 1])

        expanded = tl.expand_dims(base, (0, 1, 3))
        tl.static_assert(expanded.shape == [1, 1, N, 1])

        expanded = tl.expand_dims(base, (-4, 2, -1))
        tl.static_assert(expanded.shape == [1, N, 1, 1])

        expanded = tl.expand_dims(base, (3, 1, 2))
        tl.static_assert(expanded.shape == [N, 1, 1, 1])

        # zero-dimensional input
        total = tl.sum(base)
        tl.static_assert(total.shape == [])

        expanded = tl.expand_dims(total, 0)
        tl.static_assert(expanded.shape == [1])

        expanded = tl.expand_dims(total, -1)
        tl.static_assert(expanded.shape == [1])

        # N is a scalar that's not even a tl.tensor -- this should work too.
        expanded = tl.expand_dims(N, -1)
        tl.static_assert(expanded.shape == [1])

    N = 32
    placeholder = torch.empty((), device=device)
    expand_dims_kernel[(1, )](placeholder, N)
@pytest.mark.interpreter
def test_expand_dims_error_cases(device):
    """Out-of-range and duplicate axes passed to ``tl.expand_dims`` must raise TritonError."""

    @triton.jit
    def dim_out_of_range1(dummy, N: tl.constexpr):
        base = tl.arange(0, N)

        expanded = tl.expand_dims(base, -2)
        expanded = tl.expand_dims(base, -3)

    @triton.jit
    def dim_out_of_range2(dummy, N: tl.constexpr):
        base = tl.arange(0, N)

        expanded = tl.expand_dims(base, 1)
        expanded = tl.expand_dims(base, 2)

    @triton.jit
    def dim_out_of_range3(dummy, N: tl.constexpr):
        base = tl.arange(0, 1)
        total = tl.sum(base)

        expanded = tl.expand_dims(total, 1)

    @triton.jit
    def duplicate_dim1(dummy, N: tl.constexpr):
        base = tl.arange(0, N)

        expanded = tl.expand_dims(base, (0, 0))

    @triton.jit
    def duplicate_dim2(dummy, N: tl.constexpr):
        base = tl.arange(0, N)

        expanded = tl.expand_dims(base, (0, -3))

    N = 32
    placeholder = torch.empty((), device=device)

    def failure_cause(kernel_fn):
        # Launch the kernel, require a TritonError, and return the text of its __cause__.
        with pytest.raises(triton.TritonError) as exc_info:
            kernel_fn[(1, )](placeholder, N)
        return str(exc_info.value.__cause__)

    assert "invalid axis -3" in failure_cause(dim_out_of_range1)
    assert "invalid axis 2" in failure_cause(dim_out_of_range2)
    assert "invalid axis 1" in failure_cause(dim_out_of_range3)

    duplicate_axes = r"duplicate axes, normalized axes = \[0, 0\]"
    assert re.search(duplicate_axes, failure_cause(duplicate_dim1))
    assert re.search(duplicate_axes, failure_cause(duplicate_dim2))
# ----------------------------
# test invalid program id axis
# ----------------------------
@pytest.mark.interpreter
def test_invalid_pid_axis(device):
    """``tl.program_id`` with an axis other than 0, 1 or 2 must raise TritonError."""
    buffer = torch.empty(128, device=device)

    @triton.jit
    def _kernel(dst):
        pid = tl.program_id(20)

    with pytest.raises(triton.TritonError) as exc_info:
        _kernel[(1, )](buffer)

    cause_text = str(exc_info.value.__cause__)
    assert re.search(r"program_id axis must be 0, 1, or 2 but got 20", cause_text)
# ---------------
# test where
# ---------------
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16 + ["*int32"])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_where(dtype, num_ctas, device):
    """Check ``tl.where`` on values, on values cast to pointers, and on scalar pointers."""
    # The "*int32" case stores int64 data and selects between pointers.
    select_ptrs = dtype == "*int32"
    if select_ptrs:
        dtype = "int64"
    check_type_supported(dtype, device)

    @triton.jit
    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr,
                     TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        decide = tl.load(cond_ptr + offsets, mask=mask)
        if TEST_SCALAR_POINTERS:
            # choose one base pointer from the first condition element
            ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
            output = tl.load(ptr + offsets, mask=mask)
        else:
            if TEST_POINTERS:
                a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
                b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
            else:
                a = tl.load(a_ptr + offsets, mask=mask)
                b = tl.load(b_ptr + offsets, mask=mask)
            output = tl.where(decide, a, b)
        tl.store(output_ptr + offsets, output, mask=mask)

    SIZE = 1_000
    rng = RandomState(17)
    cond = numpy_random(SIZE, 'bool', rng)
    x = numpy_random(SIZE, dtype_str=dtype, rs=rng)
    y = numpy_random(SIZE, dtype_str=dtype, rs=rng)
    expected = np.where(cond, x, y)

    cond_tri = to_triton(cond, device=device)
    x_tri = to_triton(x, device=device, dst_type=dtype)
    y_tri = to_triton(y, device=device, dst_type=dtype)
    z_tri = to_triton(np.empty(SIZE, dtype=expected.dtype), device=device, dst_type=dtype)

    def grid(meta):
        return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), )

    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs,
                       TEST_SCALAR_POINTERS=False, num_ctas=num_ctas)
    assert (expected == to_numpy(z_tri)).all()

    if not select_ptrs:
        return

    # Scalar-pointer variant: the whole output follows cond[0].
    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs,
                       TEST_SCALAR_POINTERS=True)
    expected = np.where(cond[0], x, y)
    assert (expected == to_numpy(z_tri)).all()
@pytest.mark.interpreter
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_where_broadcast(num_ctas, device):
    """Check ``tl.where`` with a row-vector condition and with a Python bool condition."""

    @triton.jit
    def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]

        # the condition has shape (1, BLOCK_SIZE); the values are a full square block
        mask = tl.load(cond_ptr + yoffsets)
        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
        res = tl.where(mask, vals, 0.)
        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

    @triton.jit
    def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
        xoffsets = tl.arange(0, BLOCK_SIZE)[:, None]
        yoffsets = tl.arange(0, BLOCK_SIZE)[None, :]
        mask = False
        vals = tl.load(a_ptr + yoffsets + BLOCK_SIZE * xoffsets)
        res = tl.where(mask, vals, 0.)
        tl.store(out_ptr + yoffsets + BLOCK_SIZE * xoffsets, res)

    SIZE = 32
    dtype = 'float32'
    rng = RandomState(17)
    values = numpy_random((SIZE, SIZE), dtype_str=dtype, rs=rng)
    condition = numpy_random(SIZE, 'bool', rs=rng)
    expected = np.where(condition, values, 0)

    cond_tri = to_triton(condition, device=device)
    values_tri = to_triton(values, device=device, dst_type=dtype)
    out_tri = to_triton(np.empty((SIZE, SIZE), dtype=expected.dtype), device=device, dst_type=dtype)

    where_kernel[(1, )](cond_tri, values_tri, out_tri, SIZE)
    assert (expected == to_numpy(out_tri)).all()

    where_scalar_condition[(1, )](values_tri, out_tri, SIZE, num_ctas=num_ctas)
    expected = np.where(0, values, 0)
    assert (expected == to_numpy(out_tri)).all()
# ---------------
# test unary ops
# ---------------
@pytest.mark.interpreter
@pytest.mark.parametrize("dtype_x, expr",
                         [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + [(dtype_x, ' ~x')
                                                                                   for dtype_x in int_dtypes])
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_unary_op(dtype_x, expr, num_ctas, device):
    """Check unary negation for every dtype and bitwise-not for the signed integer dtypes."""
    _test_unary(dtype_x=dtype_x, expr=expr, device=device, num_ctas=num_ctas)