Check tile sizes in config #180

Merged (1 commit) on Dec 7, 2023


thomasfaingnaert (Member)

Extracted from #179

maleadt (Member) commented Dec 7, 2023

Benchmark results for commit 7e244b0 (comparing to 5871d06):

test master PR Δmin
Tropical GEMM Float32*Float32=Float32 (128×256) · (256×128) (TN) OP (8, 16, 2), base shape (4, 8, 1) 14.6 μs ± 0.981% (14.5 … 15.0 μs), 80 regs; 124.0 μs ± 0.148% (123.0 … 124.0 μs), 116 regs; +747.5% ❌
WMMA GEMM Float16*Float16+Float16=Float16 (2048×2048) · (2048×2048) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 217.0 μs ± 3.93% (205.0 … 229.0 μs) 226.0 μs ± 0.862% (221.0 … 231.0 μs) +7.9% ❌
Tropical GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (8, 16, 2), base shape (4, 8, 1) 9.18 μs ± 1.59% (8.82 … 9.54 μs), 80 regs; 63.6 μs ± 0.266% (63.2 … 64.1 μs), 118 regs; +616.2% ❌
Tropical GEMM Float32*Float32=Float32 (256×128) · (128×256) (TT) OP (8, 16, 2), base shape (4, 8, 1) 9.57 μs ± 1.72% (9.3 … 9.78 μs), 82 regs; 64.0 μs ± 0.25% (63.7 … 64.4 μs), 123 regs; +584.6% ❌
Tropical GEMM Float32*Float32=Float32 (256×256) · (256×256) (TT) OP (8, 16, 2), base shape (4, 8, 1) 16.3 μs ± 1.01% (16.0 … 16.7 μs), 82 regs; 118.0 μs ± 6.8% (110.0 … 128.0 μs), 123 regs; +586.6% ❌
Tropical GEMM Float32*Float32=Float32 (128×256) · (256×128) (NT) OP (8, 16, 2), base shape (4, 8, 1) 18.3 μs ± 0.839% (18.1 … 18.6 μs), 80 regs; 126.0 μs ± 3.62% (111.0 … 128.0 μs), 111 regs; +514.5% ❌
Tropical GEMM Float32*Float32=Float32 (128×256) · (256×128) (TT) OP (8, 16, 2), base shape (4, 8, 1) 16.1 μs ± 0.92% (15.7 … 16.5 μs), 82 regs; 125.0 μs ± 0.124% (125.0 … 126.0 μs), 123 regs; +692.4% ❌
Tropical GEMM Float32*Float32=Float32 (256×256) · (256×256) (NT) OP (8, 16, 2), base shape (4, 8, 1) 18.5 μs ± 0.836% (18.4 … 18.8 μs), 80 regs; 112.0 μs ± 0.146% (111.0 … 112.0 μs), 111 regs; +506.5% ❌
Tropical GEMM Float32*Float32=Float32 (256×128) · (128×256) (NT) OP (8, 16, 2), base shape (4, 8, 1) 10.7 μs ± 1.32% (10.5 … 11.0 μs), 80 regs; 64.9 μs ± 0.264% (64.4 … 65.3 μs), 111 regs; +513.6% ❌
Tropical GEMM Float32*Float32=Float32 (256×256) · (256×256) (TN) OP (8, 16, 2), base shape (4, 8, 1) 14.8 μs ± 0.957% (14.5 … 15.0 μs), 80 regs; 121.0 μs ± 5.09% (108.0 … 124.0 μs), 116 regs; +645.9% ❌
Tropical GEMM Float32*Float32=Float32 (256×128) · (128×256) (TN) OP (8, 16, 2), base shape (4, 8, 1) 8.81 μs ± 1.97% (8.58 … 9.06 μs), 80 regs; 63.2 μs ± 0.22% (62.9 … 63.7 μs), 116 regs; +633.3% ❌
Tropical GEMM Float32*Float32=Float32 (128×256) · (256×128) (NN) OP (8, 16, 2), base shape (4, 8, 1) 15.7 μs ± 1.08% (15.5 … 16.0 μs), 80 regs; 124.0 μs ± 2.99% (110.0 … 125.0 μs), 118 regs; +607.7% ❌
Tropical GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (8, 16, 2), base shape (4, 8, 1) 9.41 μs ± 1.61% (9.06 … 9.78 μs), 82 regs; 56.0 μs ± 0.259% (55.8 … 56.5 μs), 123 regs; +515.8% ❌
Tropical GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (8, 16, 2), base shape (4, 8, 1) 8.63 μs ± 1.94% (8.34 … 8.82 μs), 80 regs; 55.4 μs ± 0.242% (54.8 … 55.8 μs), 116 regs; +557.1% ❌
Tropical GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (8, 16, 2), base shape (4, 8, 1) 10.5 μs ± 0.695% (10.0 … 11.0 μs), 80 regs; 57.0 μs ± 0.3% (56.7 … 57.5 μs), 111 regs; +466.7% ❌
Tropical GEMM Float32*Float32=Float32 (256×256) · (256×256) (NN) OP (8, 16, 2), base shape (4, 8, 1) 15.9 μs ± 0.91% (15.5 … 16.2 μs), 80 regs; 110.0 μs ± 0.148% (110.0 … 111.0 μs), 118 regs; +607.7% ❌
Tropical GEMM Float32*Float32=Float32 (256×128) · (128×256) (NN) OP (8, 16, 2), base shape (4, 8, 1) 9.36 μs ± 1.83% (9.06 … 9.54 μs), 80 regs; 63.8 μs ± 0.255% (63.4 … 64.4 μs), 118 regs; +600.0% ❌
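The Δmin column appears to be the relative change between the minimum (fastest) runtime on master and on the PR, i.e. (PR min - master min) / master min. A minimal Python sketch, assuming that definition, recomputes it from the rounded minima printed above; because those inputs are already rounded, the results agree with the reported column only to within a few tenths of a percent. `delta_min` is a hypothetical helper, not part of the benchmark harness.

```python
# Hypothetical helper (not from the benchmark harness): recompute the Δmin
# column, ASSUMING Δmin = (PR minimum - master minimum) / master minimum.

def delta_min(master_min_us: float, pr_min_us: float) -> float:
    """Relative change of the minimum runtime, in percent."""
    return (pr_min_us - master_min_us) / master_min_us * 100.0


# (master minimum, PR minimum) in μs, copied from three rows of the table.
rows = {
    "Tropical 128×128 NN": (8.82, 63.2),     # reported +616.2%
    "Tropical 256×128 TT": (9.3, 63.7),      # reported +584.6%
    "WMMA 2048×2048 TT": (205.0, 221.0),     # reported +7.9%
}

for name, (master, pr) in rows.items():
    print(f"{name}: {delta_min(master, pr):+.1f}%")
```

Working from minima rather than means makes the comparison less sensitive to one-off slow runs, which matters for entries such as the 256×256 TT case whose PR timings spread from 110.0 to 128.0 μs.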

Comparison with baseline

test GemmKernels Baseline %
FPU GEMM Float32*Float32=Float32 (256×256) · (256×256) (TT) OP (8, 16, 2), base shape (4, 8, 1) 15.0 μs ± 1.13% (14.8 … 15.3 μs) 7.23 μs ± 1.77% (6.91 … 7.63 μs) 46.8
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (4, 32, 1), base shape (1, 32, 1) 16.0 μs ± 1.03% (15.7 … 16.2 μs) 4.21 μs ± 3.87% (4.05 … 4.53 μs) 25.8
WMMA GEMM Float16*Float16+Float16=Float16 (2048×2048) · (2048×2048) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 232.0 μs ± 2.21% (210.0 … 238.0 μs) 191.0 μs ± 2.18% (181.0 … 195.0 μs) 86.0
WMMA GEMM Float16*Float16+Float32=Float32 (128×128) · (128×128) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 7.92 μs ± 1.62% (7.63 … 8.34 μs) 2.38 μs ± 6.71% (2.15 … 2.62 μs) 28.1
WMMA GEMM Float16*Float16+Float16=Float16 (2048×2048) · (2048×2048) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 218.0 μs ± 3.72% (205.0 … 230.0 μs) 183.0 μs ± 2.3% (173.0 … 194.0 μs) 84.2
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (16, 16, 1), base shape (8, 4, 1) 14.5 μs ± 1.09% (14.3 … 14.8 μs) 4.03 μs ± 3.37% (3.81 … 4.29 μs) 26.7
FPU GEMM Float32*Float32=Float32 (2048×2048) · (2048×2048) (NN) OP (8, 16, 2), base shape (4, 8, 1) 1.84 ms ± 2.78% (1.81 … 2.02 ms) 994.0 μs ± 1.77% (932.0 … 1020.0 μs) 51.5
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (32, 4, 1), base shape (32, 1, 1) 13.1 μs ± 0.888% (12.6 … 13.4 μs) 4.04 μs ± 3.44% (3.81 … 4.29 μs) 30.2
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (32, 4, 1), base shape (16, 2, 1) 19.1 μs ± 0.908% (18.8 … 19.3 μs) 4.1 μs ± 2.49% (3.81 … 4.53 μs) 20.3
FPU GEMM Float16*Float16=Float32 (2048×2048) · (2048×2048) (TT) OP (8, 16, 2), base shape (4, 8, 1) 1.84 ms ± 2.76% (1.8 … 2.02 ms) 292.0 μs ± 0.415% (289.0 … 294.0 μs) 16.1
FPU GEMM Float64*Float64=Float64 (2048×2048) · (2048×2048) (TN) OP (8, 16, 2), base shape (4, 8, 1) 46.2 ms ± 0.616% (46.0 … 52.4 ms) 41.0 ms ± 0.00241% (41.0 … 41.0 ms) 89.2
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (32, 4, 1), base shape (16, 2, 1) 19.1 μs ± 0.903% (18.8 … 19.3 μs) 4.3 μs ± 3.39% (4.05 … 4.53 μs) 21.5
FPU GEMM Float64*Float64=Float64 (256×256) · (256×256) (TN) OP (8, 16, 2), base shape (4, 8, 1) 270.0 μs ± 6.17% (256.0 … 293.0 μs) 109.0 μs ± 0.154% (109.0 … 110.0 μs) 42.7
WMMA GEMM Float16*Float16+Float16=Float16 (256×256) · (256×256) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 7.42 μs ± 1.9% (7.15 … 7.87 μs) 3.84 μs ± 2.94% (3.58 … 4.29 μs) 50.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (8, 16, 2), base shape (4, 8, 1) 15.3 μs ± 1.16% (15.0 … 15.7 μs) 4.09 μs ± 2.32% (3.81 … 4.53 μs) 25.4
WMMA GEMM Float16*Float16+Float16=Float16 (512×256) · (256×512) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.71 μs ± 2.02% (8.34 … 9.06 μs) 5.09 μs ± 2.86% (4.77 … 5.48 μs) 57.1
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (4, 8, 1), base shape (4, 8, 1) 15.2 μs ± 1.16% (15.0 … 15.5 μs) 4.32 μs ± 3.27% (4.05 … 4.77 μs) 27.0
FPU GEMM Float32*Float32=Float32 (2048×2048) · (2048×2048) (TT) OP (8, 16, 2), base shape (4, 8, 1) 1.89 ms ± 3.25% (1.84 … 2.05 ms) 1.09 ms ± 1.07% (1.02 … 1.12 ms) 55.8
FPU GEMM Float64*Float64=Float64 (256×256) · (256×256) (NN) OP (8, 16, 2), base shape (4, 8, 1) 285.0 μs ± 4.76% (257.0 … 292.0 μs) 109.0 μs ± 0.156% (109.0 … 110.0 μs) 42.5
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (4, 8, 1), base shape (4, 8, 1) 15.2 μs ± 1.08% (15.0 … 15.5 μs) 4.03 μs ± 3.48% (3.81 … 4.29 μs) 25.4
WMMA GEMM Float16*Float16+Float32=Float32 (2048×2048) · (2048×2048) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 358.0 μs ± 4.68% (343.0 … 391.0 μs) 292.0 μs ± 0.163% (291.0 … 294.0 μs) 84.8
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (16, 16, 1), base shape (4, 8, 1) 15.3 μs ± 0.874% (15.0 … 15.5 μs) 4.09 μs ± 2.48% (3.81 … 4.53 μs) 25.4
FPU GEMM Float32*Float32=Float32 (2048×2048) · (2048×2048) (TN) OP (8, 16, 2), base shape (4, 8, 1) 1.45 ms ± 1.68% (1.43 … 1.53 ms) 1.1 ms ± 1.68% (1.03 … 1.13 ms) 72.4
WMMA GEMM Float16*Float16+Float32=Float32 (128×128) · (128×128) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.08 μs ± 2.11% (7.87 … 8.34 μs) 2.21 μs ± 7.56% (1.91 … 2.38 μs) 24.2
FPU GEMM Float16*Float16=Float32 (2048×2048) · (2048×2048) (NN) OP (8, 16, 2), base shape (4, 8, 1) 1.85 ms ± 2.77% (1.82 … 2.03 ms) 293.0 μs ± 0.253% (291.0 … 294.0 μs) 16.0
FPU GEMM Float64*Float64=Float64 (128×128) · (128×128) (TT) OP (8, 16, 2), base shape (4, 8, 1) 143.0 μs ± 5.56% (129.0 … 150.0 μs) 28.8 μs ± 0.542% (28.6 … 29.1 μs) 22.1
WMMA GEMM Float16*Float16+Float32=Float32 (128×128) · (128×128) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.02 μs ± 2.01% (7.87 … 8.34 μs) 2.49 μs ± 6.35% (2.15 … 2.62 μs) 27.3
FPU GEMM Float64*Float64=Float64 (2048×2048) · (2048×2048) (TT) OP (8, 16, 2), base shape (4, 8, 1) 46.3 ms ± 0.945% (46.0 … 52.4 ms) 41.0 ms ± 0.00309% (41.0 … 41.0 ms) 89.2
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (32, 4, 1), base shape (32, 1, 1) 12.6 μs ± 1.27% (12.4 … 12.9 μs) 4.22 μs ± 3.81% (4.05 … 4.53 μs) 32.7
WMMA GEMM Float16*Float16+Float32=Float32 (2048×2048) · (2048×2048) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 366.0 μs ± 3.98% (346.0 … 396.0 μs) 291.0 μs ± 0.203% (290.0 … 293.0 μs) 83.7
FPU GEMM Float16*Float16=Float32 (256×256) · (256×256) (TN) OP (8, 16, 2), base shape (4, 8, 1) 13.0 μs ± 1.13% (12.6 … 13.1 μs) 3.79 μs ± 3.14% (3.58 … 4.29 μs) 28.3
WMMA GEMM Float16*Float16+Float16=Float16 (256×256) · (256×256) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.43 μs ± 1.91% (8.11 … 8.82 μs) 3.92 μs ± 3.66% (3.81 … 4.29 μs) 47.1
WMMA GEMM Float16*Float16+Float32=Float32 (2048×2048) · (2048×2048) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 369.0 μs ± 4.92% (343.0 … 390.0 μs) 292.0 μs ± 0.299% (290.0 … 294.0 μs) 84.7
FPU GEMM Float16*Float16=Float32 (2048×2048) · (2048×2048) (NT) OP (8, 16, 2), base shape (4, 8, 1) 2.29 ms ± 2.74% (2.26 … 2.56 ms) 292.0 μs ± 0.245% (290.0 … 294.0 μs) 12.8
FPU GEMM Float16*Float16=Float32 (128×128) · (128×128) (TN) OP (8, 16, 2), base shape (4, 8, 1) 6.91 μs ± 0.859% (6.44 … 7.39 μs) 2.12 μs ± 7.8% (1.91 … 2.38 μs) 29.6
WMMA GEMM Float16*Float16+Float32=Float32 (256×128) · (128×256) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 9.08 μs ± 1.96% (8.82 … 9.3 μs) 2.82 μs ± 4.16% (2.38 … 3.1 μs) 27.0
WMMA GEMM Float16*Float16+Float32=Float32 (128×256) · (256×128) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.3 μs ± 1.2% (12.2 … 12.6 μs) 2.92 μs ± 5.18% (2.62 … 3.1 μs) 21.6
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (4, 8, 2), base shape (4, 8, 1) 16.1 μs ± 1.06% (15.7 … 16.5 μs) 4.21 μs ± 3.81% (4.05 … 4.53 μs) 25.8
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (8, 16, 2), base shape (4, 8, 1) 16.4 μs ± 0.97% (16.0 … 16.7 μs) 4.22 μs ± 3.97% (4.05 … 4.53 μs) 25.4
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (32, 4, 1), base shape (32, 1, 1) 31.6 μs ± 0.464% (31.5 … 31.9 μs) 4.09 μs ± 2.47% (3.81 … 4.53 μs) 12.1
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (8, 16, 2), base shape (4, 8, 1) 13.6 μs ± 0.951% (13.1 … 13.8 μs) 4.04 μs ± 3.42% (3.81 … 4.29 μs) 29.1
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (4, 32, 1), base shape (1, 32, 1) 15.7 μs ± 1.06% (15.5 … 16.0 μs) 4.1 μs ± 2.41% (3.81 … 4.53 μs) 24.6
FPU GEMM Float16*Float16=Float32 (256×256) · (256×256) (NT) OP (8, 16, 2), base shape (4, 8, 1) 18.1 μs ± 0.87% (17.9 … 18.4 μs) 4.0 μs ± 3.84% (3.81 … 4.53 μs) 21.3
WMMA GEMM Float16*Float16+Float16=Float16 (2048×2048) · (2048×2048) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 226.0 μs ± 0.862% (221.0 … 231.0 μs) 188.0 μs ± 2.25% (178.0 … 195.0 μs) 80.4
WMMA GEMM Float16*Float16+Float16=Float16 (2048×2048) · (2048×2048) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 230.0 μs ± 1.07% (216.0 … 236.0 μs) 183.0 μs ± 2.65% (174.0 … 190.0 μs) 80.7
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (4, 32, 1), base shape (2, 16, 1) 21.9 μs ± 0.708% (21.7 … 22.2 μs) 4.31 μs ± 3.28% (4.05 … 4.77 μs) 18.7
WMMA GEMM Float16*Float16+Float32=Float32 (128×256) · (256×128) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.2 μs ± 1.44% (11.9 … 12.4 μs) 3.05 μs ± 3.34% (2.62 … 3.1 μs) 22.0
FPU GEMM Float64*Float64=Float64 (128×128) · (128×128) (NT) OP (8, 16, 2), base shape (4, 8, 1) 131.0 μs ± 0.233% (130.0 … 132.0 μs) 28.8 μs ± 0.56% (28.6 … 29.1 μs) 22.1
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (16, 16, 1), base shape (8, 4, 1) 14.5 μs ± 1.06% (14.3 … 14.8 μs) 4.09 μs ± 2.35% (3.81 … 4.53 μs) 26.7
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (4, 8, 2), base shape (4, 8, 1) 12.4 μs ± 1.36% (12.2 … 12.6 μs) 4.05 μs ± 3.44% (3.81 … 4.29 μs) 31.4
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (16, 16, 1), base shape (4, 8, 1) 15.4 μs ± 1.09% (15.3 … 15.7 μs) 4.04 μs ± 3.37% (3.81 … 4.53 μs) 25.0
WMMA GEMM Float16*Float16+Float16=Float16 (256×512) · (512×256) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.0 μs ± 1.39% (11.7 … 12.4 μs) 5.56 μs ± 3.79% (5.01 … 5.96 μs) 42.9
WMMA GEMM Float16*Float16+Float16=Float16 (512×512) · (512×512) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.3 μs ± 1.47% (11.9 … 12.6 μs) 7.26 μs ± 1.89% (7.15 … 7.63 μs) 60.0
WMMA GEMM Float16*Float16+Float16=Float16 (512×256) · (256×512) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.75 μs ± 2.29% (8.34 … 9.06 μs) 4.99 μs ± 3.5% (4.77 … 5.25 μs) 57.1
FPU GEMM Float64*Float64=Float64 (128×128) · (128×128) (NN) OP (8, 16, 2), base shape (4, 8, 1) 142.0 μs ± 5.85% (130.0 … 149.0 μs) 28.8 μs ± 0.563% (28.6 … 29.1 μs) 22.1
WMMA GEMM Float16*Float16+Float32=Float32 (256×256) · (256×256) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.4 μs ± 1.38% (12.2 … 12.6 μs) 3.78 μs ± 3.4% (3.58 … 4.29 μs) 29.4
WMMA GEMM Float16*Float16+Float32=Float32 (256×256) · (256×256) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.3 μs ± 1.2% (11.9 … 12.4 μs) 3.86 μs ± 3.74% (3.58 … 4.29 μs) 30.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (4, 32, 1), base shape (2, 16, 1) 22.0 μs ± 0.703% (21.7 … 22.2 μs) 4.03 μs ± 3.51% (3.81 … 4.29 μs) 17.6
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (32, 4, 1), base shape (16, 2, 1) 19.1 μs ± 0.908% (18.8 … 19.3 μs) 4.04 μs ± 3.46% (3.81 … 4.29 μs) 20.3
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (4, 8, 2), base shape (4, 8, 1) 11.8 μs ± 1.36% (11.7 … 12.2 μs) 4.3 μs ± 3.37% (4.05 … 4.53 μs) 34.7
WMMA GEMM Float16*Float16+Float16=Float16 (256×512) · (512×256) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.2 μs ± 1.46% (11.9 … 12.4 μs) 5.51 μs ± 3.17% (5.01 … 5.72 μs) 42.0
WMMA GEMM Float16*Float16+Float16=Float16 (512×256) · (256×512) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.66 μs ± 1.93% (8.34 … 9.06 μs) 5.11 μs ± 2.68% (4.77 … 5.48 μs) 57.1
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (4, 16, 1), base shape (4, 8, 1) 15.2 μs ± 1.13% (15.0 … 15.5 μs) 4.31 μs ± 3.32% (4.05 … 4.77 μs) 27.0
WMMA GEMM Float16*Float16+Float32=Float32 (256×256) · (256×256) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.3 μs ± 1.16% (11.9 … 12.6 μs) 3.91 μs ± 3.58% (3.58 … 4.29 μs) 30.0
FPU GEMM Float64*Float64=Float64 (2048×2048) · (2048×2048) (NT) OP (8, 16, 2), base shape (4, 8, 1) 46.6 ms ± 0.74% (46.3 … 52.8 ms) 41.0 ms ± 0.00261% (41.0 … 41.0 ms) 88.5
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (8, 8, 1), base shape (4, 8, 1) 15.4 μs ± 0.891% (15.0 … 15.7 μs) 4.1 μs ± 2.45% (3.81 … 4.53 μs) 25.4
WMMA GEMM Float16*Float16+Float16=Float16 (512×256) · (256×512) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.76 μs ± 2.16% (8.34 … 9.06 μs) 4.88 μs ± 2.91% (4.53 … 5.25 μs) 54.3
WMMA GEMM Float16*Float16+Float16=Float16 (512×512) · (512×512) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.4 μs ± 1.4% (12.2 … 12.6 μs) 7.46 μs ± 1.67% (7.15 … 7.87 μs) 58.8
WMMA GEMM Float16*Float16+Float32=Float32 (2048×2048) · (2048×2048) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 374.0 μs ± 5.06% (350.0 … 396.0 μs) 293.0 μs ± 0.317% (291.0 … 295.0 μs) 83.0
WMMA GEMM Float16*Float16+Float16=Float16 (512×512) · (512×512) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.4 μs ± 1.57% (11.9 … 12.9 μs) 7.32 μs ± 2.09% (7.15 … 7.63 μs) 60.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (4, 8, 1), base shape (4, 8, 1) 13.4 μs ± 1.06% (13.1 … 13.6 μs) 4.09 μs ± 2.38% (3.81 … 4.53 μs) 29.1
FPU GEMM Float16*Float16=Float32 (128×128) · (128×128) (NT) OP (8, 16, 2), base shape (4, 8, 1) 9.2 μs ± 1.57% (8.82 … 9.54 μs) 2.25 μs ± 7.06% (1.91 … 2.38 μs) 21.6
WMMA GEMM Float16*Float16+Float32=Float32 (256×128) · (128×256) (NN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 9.27 μs ± 1.67% (9.06 … 9.54 μs) 2.9 μs ± 5.03% (2.62 … 3.1 μs) 28.9
WMMA GEMM Float16*Float16+Float32=Float32 (128×128) · (128×128) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 8.03 μs ± 2.05% (7.87 … 8.34 μs) 2.23 μs ± 7.41% (1.91 … 2.38 μs) 24.2
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (16, 16, 1), base shape (8, 4, 1) 14.6 μs ± 0.904% (14.3 … 15.0 μs) 4.19 μs ± 3.59% (4.05 … 4.53 μs) 28.3
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (4, 32, 1), base shape (1, 32, 1) 33.0 μs ± 0.414% (32.7 … 33.1 μs) 4.04 μs ± 3.47% (3.81 … 4.29 μs) 11.7
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (4, 16, 1), base shape (4, 8, 1) 15.2 μs ± 1.15% (15.0 … 15.5 μs) 4.04 μs ± 3.52% (3.81 … 4.29 μs) 25.4
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (32, 4, 1), base shape (16, 2, 1) 19.1 μs ± 0.861% (18.8 … 19.3 μs) 4.23 μs ± 4.11% (4.05 … 4.53 μs) 21.5
FPU GEMM Float32*Float32=Float32 (2048×2048) · (2048×2048) (NT) OP (8, 16, 2), base shape (4, 8, 1) 2.36 ms ± 2.92% (2.32 … 2.63 ms) 851.0 μs ± 1.45% (804.0 … 864.0 μs) 34.6
FPU GEMM Float64*Float64=Float64 (256×256) · (256×256) (TT) OP (8, 16, 2), base shape (4, 8, 1) 277.0 μs ± 6.23% (256.0 … 295.0 μs) 109.0 μs ± 0.156% (109.0 … 110.0 μs) 42.6
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TT) OP (8, 8, 1), base shape (4, 8, 1) 15.3 μs ± 0.947% (15.0 … 15.5 μs) 4.04 μs ± 3.32% (3.81 … 4.29 μs) 25.4
WMMA GEMM Float16*Float16+Float16=Float16 (256×256) · (256×256) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 7.48 μs ± 1.82% (7.15 … 7.87 μs) 3.75 μs ± 3.78% (3.58 … 4.05 μs) 50.0
WMMA GEMM Float16*Float16+Float32=Float32 (128×256) · (256×128) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.2 μs ± 1.45% (11.9 … 12.4 μs) 3.08 μs ± 2.79% (2.86 … 3.34 μs) 24.0
FPU GEMM Float16*Float16=Float32 (128×128) · (128×128) (TT) OP (8, 16, 2), base shape (4, 8, 1) 8.19 μs ± 1.53% (7.87 … 8.58 μs) 2.19 μs ± 7.9% (1.91 … 2.38 μs) 24.2
WMMA GEMM Float16*Float16+Float16=Float16 (512×512) · (512×512) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.4 μs ± 1.36% (12.2 … 12.6 μs) 7.43 μs ± 1.54% (7.15 … 7.87 μs) 58.8
WMMA GEMM Float16*Float16+Float32=Float32 (256×256) · (256×256) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.4 μs ± 1.43% (12.2 … 12.6 μs) 4.0 μs ± 3.86% (3.81 … 4.53 μs) 31.4
FPU GEMM Float32*Float32=Float32 (256×256) · (256×256) (NN) OP (8, 16, 2), base shape (4, 8, 1) 15.2 μs ± 1.14% (15.0 … 15.5 μs) 6.83 μs ± 1.82% (6.44 … 6.91 μs) 42.9
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (8, 8, 1), base shape (4, 8, 1) 15.4 μs ± 0.984% (15.3 … 15.7 μs) 4.21 μs ± 3.83% (4.05 … 4.53 μs) 26.6
FPU GEMM Float64*Float64=Float64 (256×256) · (256×256) (NT) OP (8, 16, 2), base shape (4, 8, 1) 291.0 μs ± 2.14% (257.0 … 293.0 μs) 109.0 μs ± 0.161% (109.0 … 110.0 μs) 42.5
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (4, 16, 1), base shape (4, 8, 1) 15.3 μs ± 1.02% (15.0 … 15.7 μs) 4.22 μs ± 3.74% (4.05 … 4.53 μs) 27.0
WMMA GEMM Float16*Float16+Float32=Float32 (256×128) · (128×256) (TT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 9.22 μs ± 1.45% (8.82 … 9.54 μs) 2.83 μs ± 4.12% (2.38 … 3.1 μs) 27.0
FPU GEMM Float16*Float16=Float32 (128×128) · (128×128) (NN) OP (8, 16, 2), base shape (4, 8, 1) 9.22 μs ± 1.38% (8.82 … 9.54 μs) 2.21 μs ± 7.78% (1.91 … 3.1 μs) 21.6
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (8, 8, 1), base shape (4, 8, 1) 15.2 μs ± 1.13% (15.0 … 15.5 μs) 4.3 μs ± 3.21% (4.05 … 4.53 μs) 27.0
FPU GEMM Float16*Float16=Float32 (256×256) · (256×256) (TT) OP (8, 16, 2), base shape (4, 8, 1) 15.8 μs ± 1.05% (15.5 … 16.0 μs) 3.87 μs ± 3.68% (3.58 … 4.29 μs) 23.1
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (16, 16, 1), base shape (4, 8, 1) 15.3 μs ± 1.04% (15.0 … 15.5 μs) 4.19 μs ± 3.57% (4.05 … 4.53 μs) 27.0
WMMA GEMM Float16*Float16+Float32=Float32 (128×256) · (256×128) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.4 μs ± 1.37% (12.2 … 12.6 μs) 2.65 μs ± 6.08% (2.38 … 2.86 μs) 19.6
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (4, 8, 2), base shape (4, 8, 1) 15.0 μs ± 1.14% (14.5 … 15.5 μs) 4.09 μs ± 2.42% (3.81 … 4.53 μs) 26.2
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (32, 4, 1), base shape (32, 1, 1) 31.6 μs ± 0.462% (31.5 … 31.9 μs) 4.31 μs ± 3.23% (4.05 … 4.53 μs) 12.9
WMMA GEMM Float16*Float16+Float32=Float32 (256×128) · (128×256) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 9.17 μs ± 1.66% (8.82 … 9.3 μs) 2.94 μs ± 5.24% (2.62 … 3.1 μs) 29.7
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (4, 32, 1), base shape (2, 16, 1) 21.9 μs ± 0.774% (21.7 … 22.2 μs) 4.2 μs ± 3.92% (4.05 … 4.53 μs) 18.7
FPU GEMM Float16*Float16=Float32 (256×256) · (256×256) (NN) OP (8, 16, 2), base shape (4, 8, 1) 15.8 μs ± 1.11% (15.5 … 16.0 μs) 3.92 μs ± 3.7% (3.58 … 4.29 μs) 23.1
FPU GEMM Float16*Float16=Float32 (2048×2048) · (2048×2048) (TN) OP (8, 16, 2), base shape (4, 8, 1) 1.55 ms ± 2.85% (1.52 … 1.67 ms) 293.0 μs ± 0.294% (291.0 … 295.0 μs) 19.2
FPU GEMM Float64*Float64=Float64 (2048×2048) · (2048×2048) (NN) OP (8, 16, 2), base shape (4, 8, 1) 46.5 ms ± 0.695% (46.2 … 52.7 ms) 41.0 ms ± 0.208% (41.0 … 43.6 ms) 88.7
WMMA GEMM Float16*Float16+Float16=Float16 (256×512) · (512×256) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.1 μs ± 1.4% (11.9 … 12.4 μs) 5.25 μs ± 3.62% (4.77 … 5.72 μs) 40.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (4, 32, 1), base shape (2, 16, 1) 22.0 μs ± 0.634% (21.7 … 22.2 μs) 4.09 μs ± 2.43% (3.81 … 4.29 μs) 17.6
WMMA GEMM Float16*Float16+Float16=Float16 (256×512) · (512×256) (NT) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 12.0 μs ± 1.34% (11.7 … 12.4 μs) 4.6 μs ± 2.94% (4.29 … 5.01 μs) 36.7
FPU GEMM Float32*Float32=Float32 (256×256) · (256×256) (NT) OP (8, 16, 2), base shape (4, 8, 1) 18.2 μs ± 0.782% (17.9 … 18.6 μs) 6.75 μs ± 1.97% (6.44 … 6.91 μs) 36.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NN) OP (4, 16, 1), base shape (4, 8, 1) 15.2 μs ± 1.12% (15.0 … 15.5 μs) 4.1 μs ± 2.5% (3.81 … 4.53 μs) 25.4
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (NT) OP (4, 8, 1), base shape (4, 8, 1) 15.3 μs ± 1.11% (15.0 … 15.5 μs) 4.21 μs ± 3.78% (4.05 … 4.53 μs) 27.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (4, 32, 1), base shape (1, 32, 1) 33.0 μs ± 0.389% (32.7 … 33.4 μs) 4.3 μs ± 3.34% (4.05 … 4.77 μs) 12.4
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (8, 16, 2), base shape (4, 8, 1) 12.1 μs ± 1.31% (11.9 … 12.4 μs) 4.31 μs ± 3.26% (4.05 … 4.53 μs) 34.0
FPU GEMM Float64*Float64=Float64 (128×128) · (128×128) (TN) OP (8, 16, 2), base shape (4, 8, 1) 148.0 μs ± 0.367% (146.0 … 149.0 μs) 28.8 μs ± 0.566% (28.6 … 29.1 μs) 19.5
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (16, 16, 1), base shape (4, 8, 1) 15.2 μs ± 1.15% (15.0 … 15.5 μs) 4.31 μs ± 3.32% (4.05 … 4.53 μs) 27.0
WMMA GEMM Float16*Float16+Float16=Float16 (256×256) · (256×256) (TN) Block (128, 128, 64) Warps (4, 2) OP (16, 16, 16) 7.46 μs ± 1.58% (7.15 … 7.87 μs) 3.73 μs ± 3.85% (3.58 … 4.05 μs) 50.0
FPU GEMM Float32*Float32=Float32 (128×128) · (128×128) (TN) OP (16, 16, 1), base shape (8, 4, 1) 14.5 μs ± 1.04% (14.3 … 14.8 μs) 4.31 μs ± 3.27% (4.05 … 4.53 μs) 28.3
FPU GEMM Float32*Float32=Float32 (256×256) · (256×256) (TN) OP (8, 16, 2), base shape (4, 8, 1) 12.7 μs ± 1.4% (12.4 … 12.9 μs) 7.16 μs ± 1.31% (6.68 … 7.63 μs) 53.8
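The `%` column in the baseline comparison is consistent with the baseline's minimum runtime expressed as a percentage of GemmKernels' minimum runtime: 100 would mean parity, and lower values mean GemmKernels is slower. The short Python sketch below checks that assumed definition against three rows of the table; `relative_to_baseline` is a hypothetical name, and the true formula used by the benchmark bot may differ in rounding details.

```python
# Hypothetical helper: recompute the "%" column of the baseline comparison,
# ASSUMING % = baseline minimum / GemmKernels minimum * 100.

def relative_to_baseline(gemmkernels_min_us: float, baseline_min_us: float) -> float:
    """Baseline's fastest run as a percentage of GemmKernels' fastest run."""
    return baseline_min_us / gemmkernels_min_us * 100.0


rows = [
    # (label, GemmKernels min in μs, baseline min in μs, reported %)
    ("FPU Float32 128×128 NT, OP (4, 32, 1)", 15.7, 4.05, 25.8),
    ("FPU Float16*Float16=Float32 2048×2048 TT", 1800.0, 289.0, 16.1),  # 1.8 ms
    ("WMMA Float16 256×256 NT", 7.15, 3.58, 50.0),
]

for label, gk, base, reported in rows:
    computed = relative_to_baseline(gk, base)
    print(f"{label}: computed {computed:.1f}, reported {reported}")
```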


codecov bot commented Dec 7, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (c84a5ac) 34.23% compared to head (7e244b0) 34.94%.
Report is 1 commit behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #180      +/-   ##
==========================================
+ Coverage   34.23%   34.94%   +0.70%     
==========================================
  Files          11       11              
  Lines         923      933      +10     
==========================================
+ Hits          316      326      +10     
  Misses        607      607              
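The percentages in the Codecov report follow from the hit and line counts in the diff. A minimal Python sketch, assuming Codecov's default of two decimal places rounded down (configurable via `coverage.precision` and `coverage.round` in `codecov.yml`); the function names are hypothetical:

```python
import math


def coverage_percent(hits: int, lines: int, precision: int = 2) -> float:
    """Line coverage in percent, rounded down (toward -inf) to `precision` decimals."""
    scale = 10 ** precision
    return math.floor(hits / lines * 100 * scale) / scale


def coverage_delta(base_hits: int, base_lines: int,
                   head_hits: int, head_lines: int, precision: int = 2) -> float:
    """Coverage change computed from the unrounded ratios, then rounded down."""
    scale = 10 ** precision
    raw = (head_hits / head_lines - base_hits / base_lines) * 100
    return math.floor(raw * scale) / scale


base = coverage_percent(316, 923)            # master: 316 of 923 lines hit
head = coverage_percent(326, 933)            # head 7e244b0: 326 of 933 lines hit
delta = coverage_delta(316, 923, 326, 933)
print(f"{base:.2f}% -> {head:.2f}% ({delta:+.2f}%)")
# prints "34.23% -> 34.94% (+0.70%)"
```

Note that subtracting the already-rounded figures (34.94 - 34.23) would give 0.71, so the reported +0.70% only matches if the delta is taken from the raw ratios before rounding.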


thomasfaingnaert merged commit 3052b52 into master on Dec 7, 2023
1 check passed
thomasfaingnaert deleted the tf/check-tile-sizes branch on December 7, 2023 at 22:16