Results are as reported by this notebook. To re-run the experiments, open Google Colab, upload the notebook, and execute the cells in order.
Hardware at the time of writing (Oct 2021):
- Intel(R) Xeon(R) CPU @ 2.30GHz (1 core, 2 threads)
- 12.6GB of RAM
- NVIDIA Tesla K80 GPU with 12GB memory
Caveat: JAX does not (yet) support 64-bit floating point precision on TPUs. The JAX + TPU results are therefore not bit-identical to those of the other backends and devices, so this is not quite an apples-to-apples comparison.
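To see why single precision breaks bit-identical reproducibility, here is a minimal NumPy sketch (not part of the benchmark; the expression is an arbitrary stand-in) evaluating the same elementary math in float32 and float64:

```python
import numpy as np

# The same elementary expression, evaluated in double and single precision.
x = np.linspace(0.0, 1.0, 1000)
x32 = x.astype(np.float32)

res64 = np.sum(np.sin(x) * np.exp(-x))        # float64 throughout
res32 = np.sum(np.sin(x32) * np.exp(-x32))    # float32 throughout

# Rounding error in the low-order bits makes the single-precision
# result differ from the double-precision one.
print(res64 == float(res32))
```

The same effect applies to any backend forced down to 32-bit, which is why the TPU numbers below should be compared with that caveat in mind.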
An equation consisting of >100 terms with no data dependencies and only elementary math. This benchmark should represent a best-case scenario for vector instructions and GPU performance.
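For intuition, a purely elementwise kernel looks something like the sketch below; every output element depends only on the corresponding input elements, so backends are free to vectorize and parallelize at will. The polynomial and its coefficients are illustrative, not the actual equation of state used by the benchmark:

```python
import numpy as np

def elementwise_kernel(s, t, p):
    """Stand-in for an equation of state: many elementary operations,
    applied independently to every array element (no data dependencies)."""
    return (
        999.84 + 0.068 * t - 9.1e-3 * t**2
        + 0.82 * (s - 35.0) + 4.5e-3 * p
        - 1.0e-4 * t * (s - 35.0)
    )

s = np.full(4096, 35.0)                 # salinity
t = np.linspace(-2.0, 30.0, 4096)       # temperature
p = np.zeros(4096)                      # pressure

rho = elementwise_kernel(s, t, p)       # one output per input element
print(rho.shape)
```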
$ taskset -c 0 python run.py benchmarks/equation_of_state/
benchmarks.equation_of_state
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 pytorch 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.015 5.605
4,096 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.014 5.167
4,096 numba 10,000 0.001 0.000 0.000 0.000 0.001 0.001 0.013 3.178
4,096 aesara 10,000 0.001 0.000 0.000 0.001 0.001 0.001 0.015 2.637
4,096 tensorflow 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.009 2.143
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.010 1.000
16,384 pytorch 10,000 0.001 0.000 0.001 0.001 0.001 0.001 0.017 6.284
16,384 jax 10,000 0.002 0.000 0.001 0.001 0.002 0.002 0.019 5.396
16,384 tensorflow 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.005 4.161
16,384 numba 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.010 3.816
16,384 aesara 10,000 0.002 0.000 0.002 0.002 0.002 0.002 0.017 3.520
16,384 numpy 1,000 0.009 0.001 0.007 0.008 0.009 0.009 0.012 1.000
65,536 pytorch 1,000 0.005 0.001 0.005 0.005 0.005 0.006 0.015 16.182
65,536 jax 1,000 0.006 0.001 0.005 0.005 0.006 0.006 0.009 15.457
65,536 tensorflow 1,000 0.006 0.001 0.005 0.006 0.006 0.006 0.021 14.052
65,536 numba 1,000 0.009 0.001 0.008 0.008 0.009 0.009 0.017 10.105
65,536 aesara 1,000 0.009 0.001 0.008 0.009 0.009 0.009 0.015 9.394
65,536 numpy 100 0.088 0.003 0.079 0.086 0.088 0.090 0.097 1.000
262,144 pytorch 1,000 0.018 0.001 0.015 0.017 0.017 0.018 0.028 10.783
262,144 jax 1,000 0.020 0.002 0.017 0.019 0.019 0.020 0.035 9.667
262,144 tensorflow 1,000 0.021 0.001 0.018 0.020 0.021 0.022 0.031 8.949
262,144 numba 100 0.032 0.002 0.029 0.031 0.031 0.033 0.044 5.930
262,144 aesara 100 0.034 0.002 0.032 0.033 0.033 0.034 0.042 5.666
262,144 numpy 100 0.190 0.003 0.177 0.188 0.190 0.192 0.200 1.000
1,048,576 pytorch 100 0.075 0.003 0.068 0.073 0.074 0.077 0.083 21.187
1,048,576 jax 100 0.086 0.004 0.079 0.083 0.085 0.088 0.098 18.447
1,048,576 tensorflow 100 0.087 0.004 0.080 0.085 0.087 0.089 0.099 18.140
1,048,576 numba 100 0.132 0.004 0.125 0.129 0.132 0.134 0.145 11.976
1,048,576 aesara 100 0.140 0.004 0.131 0.137 0.140 0.142 0.157 11.301
1,048,576 numpy 10 1.585 0.015 1.568 1.573 1.579 1.595 1.612 1.000
4,194,304 pytorch 10 0.297 0.006 0.285 0.294 0.297 0.302 0.307 12.408
4,194,304 tensorflow 10 0.342 0.005 0.331 0.339 0.343 0.345 0.349 10.793
4,194,304 jax 10 0.360 0.008 0.348 0.354 0.357 0.367 0.373 10.253
4,194,304 numba 10 0.515 0.007 0.504 0.510 0.516 0.522 0.526 7.155
4,194,304 aesara 10 0.556 0.009 0.543 0.547 0.558 0.563 0.569 6.634
4,194,304 numpy 10 3.688 0.014 3.668 3.678 3.688 3.693 3.723 1.000
(time in wall seconds, less is better)
$ for backend in jax tensorflow pytorch cupy; do python run.py benchmarks/equation_of_state/ --device gpu -b $backend -b numpy; done
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 12.584
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.011 1.000
16,384 jax 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.003 61.389
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.017 1.000
65,536 jax 1,000 0.000 0.000 0.000 0.000 0.000 0.000 0.000 250.282
65,536 numpy 100 0.047 0.002 0.044 0.046 0.046 0.048 0.053 1.000
262,144 jax 1,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 699.509
262,144 numpy 100 0.309 0.011 0.256 0.304 0.310 0.314 0.336 1.000
1,048,576 jax 100 0.002 0.000 0.001 0.001 0.001 0.002 0.004 542.418
1,048,576 numpy 10 0.818 0.009 0.805 0.808 0.819 0.824 0.831 1.000
4,194,304 jax 100 0.006 0.001 0.005 0.005 0.005 0.005 0.012 544.624
4,194,304 numpy 10 3.153 0.014 3.123 3.152 3.156 3.163 3.173 1.000
(time in wall seconds, less is better)
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 tensorflow 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.006 3.709
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.012 1.000
16,384 tensorflow 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.006 18.838
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.014 1.000
65,536 tensorflow 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.006 513.398
65,536 numpy 100 0.228 0.008 0.203 0.224 0.227 0.233 0.256 1.000
262,144 tensorflow 1,000 0.000 0.000 0.000 0.000 0.000 0.000 0.004 747.237
262,144 numpy 100 0.343 0.012 0.274 0.338 0.343 0.350 0.372 1.000
1,048,576 tensorflow 1,000 0.001 0.000 0.000 0.000 0.000 0.001 0.006 1657.587
1,048,576 numpy 10 0.873 0.012 0.851 0.866 0.875 0.881 0.890 1.000
4,194,304 tensorflow 100 0.001 0.000 0.001 0.001 0.001 0.001 0.001 4226.591
4,194,304 numpy 10 3.175 0.014 3.153 3.164 3.175 3.183 3.197 1.000
(time in wall seconds, less is better)
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 pytorch 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.008 15.199
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.010 1.000
16,384 pytorch 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.008 69.659
16,384 numpy 1,000 0.009 0.001 0.007 0.008 0.009 0.009 0.016 1.000
65,536 pytorch 10,000 0.000 0.000 0.000 0.000 0.000 0.000 0.009 1393.452
65,536 numpy 100 0.286 0.088 0.126 0.151 0.331 0.338 0.397 1.000
262,144 pytorch 1,000 0.000 0.000 0.000 0.000 0.000 0.000 0.007 989.724
262,144 numpy 100 0.418 0.106 0.220 0.251 0.474 0.482 0.521 1.000
1,048,576 pytorch 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.010 716.353
1,048,576 numpy 10 0.970 0.201 0.721 0.728 1.101 1.144 1.160 1.000
4,194,304 pytorch 100 0.005 0.000 0.005 0.005 0.005 0.005 0.005 708.456
4,194,304 numpy 10 3.402 0.017 3.371 3.389 3.400 3.417 3.428 1.000
(time in wall seconds, less is better)
benchmarks.equation_of_state
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 10,000 0.002 0.000 0.001 0.002 0.002 0.002 0.005 1.000
4,096 cupy 1,000 0.007 0.002 0.005 0.006 0.006 0.009 0.018 0.223
16,384 cupy 1,000 0.008 0.002 0.006 0.006 0.007 0.009 0.020 1.085
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.011 1.000
65,536 cupy 1,000 0.008 0.002 0.006 0.006 0.007 0.009 0.017 5.290
65,536 numpy 100 0.040 0.002 0.038 0.039 0.040 0.041 0.046 1.000
262,144 cupy 1,000 0.016 0.001 0.015 0.015 0.015 0.017 0.019 9.686
262,144 numpy 100 0.154 0.003 0.148 0.152 0.154 0.157 0.166 1.000
1,048,576 cupy 100 0.058 0.004 0.053 0.054 0.054 0.060 0.065 12.664
1,048,576 numpy 10 0.728 0.012 0.710 0.725 0.726 0.733 0.753 1.000
4,194,304 cupy 10 0.208 0.009 0.203 0.203 0.204 0.207 0.233 14.708
4,194,304 numpy 10 3.062 0.014 3.039 3.053 3.066 3.073 3.083 1.000
(time in wall seconds, less is better)
$ python run.py benchmarks/equation_of_state -b jax -b numpy --device tpu
benchmarks.equation_of_state
============================
Running on TPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.002 0.001 0.001 0.002 0.002 0.003 0.007 1.044
4,096 numpy 10,000 0.002 0.001 0.002 0.002 0.002 0.003 0.021 1.000
16,384 jax 1,000 0.002 0.001 0.001 0.002 0.002 0.003 0.007 4.138
16,384 numpy 1,000 0.010 0.002 0.008 0.009 0.009 0.010 0.052 1.000
65,536 jax 1,000 0.002 0.001 0.002 0.002 0.002 0.003 0.007 56.663
65,536 numpy 100 0.139 0.009 0.101 0.137 0.140 0.144 0.158 1.000
262,144 jax 100 0.002 0.000 0.002 0.002 0.002 0.003 0.004 105.074
262,144 numpy 100 0.255 0.013 0.227 0.250 0.253 0.261 0.319 1.000
1,048,576 jax 100 0.003 0.001 0.002 0.003 0.003 0.003 0.008 359.453
1,048,576 numpy 10 1.075 0.025 1.041 1.057 1.069 1.085 1.125 1.000
4,194,304 jax 10 0.004 0.000 0.004 0.004 0.004 0.004 0.005 737.921
4,194,304 numpy 10 3.200 0.033 3.142 3.182 3.199 3.210 3.266 1.000
(time in wall seconds, less is better)
A more balanced routine with many data dependencies (stencil operations) and tensor shapes of up to 5 dimensions. This is the most expensive part of Veros, so it is the benchmark that interests me most.
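In a stencil operation each output point is coupled to its neighbours, so there is less independent work per memory access than in the elementwise case. A minimal 1-D illustration (a generic centered difference, not the Veros code):

```python
import numpy as np

def centered_diff(u, dx):
    """Second-order centered difference: each output depends on its two
    neighbouring inputs -- a minimal 1-D stencil."""
    out = np.zeros_like(u)
    out[1:-1] = (u[2:] - u[:-2]) / (2.0 * dx)
    return out

x = np.linspace(0.0, 2.0 * np.pi, 100)
dudx = centered_diff(np.sin(x), x[1] - x[0])

# interior values approximate the analytic derivative cos(x)
print(np.max(np.abs(dudx[1:-1] - np.cos(x[1:-1]))))
```

The real routine applies stencils like this across up to 5 dimensions, which stresses memory bandwidth rather than raw arithmetic throughput.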
$ taskset -c 0 python run.py benchmarks/isoneutral_mixing/
benchmarks.isoneutral_mixing
============================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.001 0.001 0.001 0.001 0.001 0.016 3.293
4,096 numba 1,000 0.002 0.002 0.001 0.001 0.001 0.001 0.050 2.904
4,096 aesara 1,000 0.003 0.003 0.002 0.003 0.003 0.003 0.059 1.334
4,096 numpy 1,000 0.004 0.002 0.004 0.004 0.004 0.004 0.063 1.000
4,096 pytorch 1,000 0.004 0.002 0.003 0.004 0.004 0.005 0.052 0.981
16,384 jax 1,000 0.006 0.001 0.005 0.006 0.006 0.006 0.021 2.664
16,384 numba 1,000 0.007 0.002 0.006 0.006 0.006 0.007 0.054 2.461
16,384 aesara 1,000 0.012 0.001 0.010 0.011 0.011 0.012 0.026 1.433
16,384 pytorch 1,000 0.012 0.003 0.010 0.011 0.011 0.012 0.061 1.424
16,384 numpy 1,000 0.017 0.002 0.015 0.016 0.016 0.017 0.043 1.000
65,536 jax 100 0.029 0.001 0.026 0.028 0.028 0.029 0.034 2.597
65,536 numba 100 0.030 0.003 0.026 0.028 0.029 0.030 0.050 2.494
65,536 pytorch 100 0.050 0.002 0.046 0.048 0.049 0.051 0.059 1.502
65,536 aesara 100 0.050 0.002 0.047 0.049 0.050 0.051 0.057 1.483
65,536 numpy 100 0.075 0.002 0.070 0.073 0.075 0.077 0.080 1.000
262,144 jax 10 0.111 0.004 0.105 0.108 0.111 0.114 0.118 2.408
262,144 numba 100 0.116 0.004 0.108 0.113 0.115 0.118 0.130 2.314
262,144 pytorch 10 0.178 0.004 0.173 0.176 0.178 0.179 0.184 1.503
262,144 aesara 10 0.190 0.004 0.183 0.187 0.190 0.194 0.197 1.408
262,144 numpy 10 0.268 0.009 0.254 0.262 0.267 0.274 0.285 1.000
1,048,576 numba 10 0.480 0.004 0.473 0.476 0.479 0.483 0.488 2.524
1,048,576 jax 10 0.599 0.007 0.592 0.593 0.597 0.604 0.615 2.020
1,048,576 aesara 10 0.834 0.011 0.816 0.828 0.833 0.835 0.862 1.451
1,048,576 pytorch 10 0.863 0.080 0.786 0.799 0.806 0.944 0.983 1.403
1,048,576 numpy 10 1.210 0.169 1.134 1.147 1.160 1.165 1.718 1.000
4,194,304 numba 10 1.947 0.011 1.926 1.939 1.953 1.956 1.958 2.739
4,194,304 jax 10 2.477 0.096 2.422 2.441 2.445 2.461 2.761 2.154
4,194,304 aesara 10 3.620 0.017 3.592 3.610 3.620 3.630 3.647 1.473
4,194,304 pytorch 10 3.668 0.026 3.631 3.658 3.663 3.675 3.730 1.454
4,194,304 numpy 10 5.334 0.042 5.271 5.297 5.333 5.374 5.388 1.000
(time in wall seconds, less is better)
$ for backend in jax pytorch cupy; do python run.py benchmarks/isoneutral_mixing/ --device gpu -b $backend -b numpy; done
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.009 4.187
4,096 numpy 1,000 0.004 0.001 0.004 0.004 0.004 0.004 0.013 1.000
16,384 jax 1,000 0.001 0.001 0.001 0.001 0.001 0.001 0.008 13.768
16,384 numpy 1,000 0.017 0.001 0.015 0.016 0.016 0.017 0.024 1.000
65,536 jax 100 0.003 0.000 0.003 0.003 0.004 0.004 0.004 21.820
65,536 numpy 100 0.075 0.004 0.070 0.073 0.074 0.076 0.094 1.000
262,144 jax 100 0.014 0.001 0.012 0.012 0.015 0.015 0.020 19.799
262,144 numpy 10 0.274 0.009 0.260 0.272 0.274 0.274 0.293 1.000
1,048,576 jax 10 0.057 0.005 0.052 0.052 0.054 0.062 0.063 21.834
1,048,576 numpy 10 1.239 0.009 1.226 1.231 1.237 1.246 1.254 1.000
4,194,304 jax 10 0.200 0.011 0.192 0.192 0.195 0.207 0.223 25.440
4,194,304 numpy 10 5.097 0.033 5.054 5.071 5.088 5.124 5.153 1.000
(time in wall seconds, less is better)
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.004 0.001 0.004 0.004 0.004 0.004 0.013 1.000
4,096 pytorch 1,000 0.006 0.001 0.005 0.005 0.005 0.007 0.014 0.746
16,384 pytorch 1,000 0.006 0.001 0.005 0.005 0.005 0.007 0.017 2.667
16,384 numpy 1,000 0.016 0.001 0.014 0.016 0.016 0.017 0.027 1.000
65,536 pytorch 100 0.007 0.001 0.006 0.007 0.007 0.008 0.015 12.932
65,536 numpy 100 0.097 0.007 0.080 0.094 0.097 0.100 0.125 1.000
262,144 pytorch 100 0.016 0.002 0.014 0.015 0.015 0.016 0.021 17.586
262,144 numpy 10 0.274 0.005 0.267 0.270 0.273 0.277 0.281 1.000
1,048,576 pytorch 10 0.051 0.003 0.048 0.050 0.050 0.050 0.060 25.531
1,048,576 numpy 10 1.292 0.011 1.276 1.284 1.292 1.296 1.316 1.000
4,194,304 pytorch 10 0.192 0.011 0.182 0.182 0.184 0.202 0.211 25.674
4,194,304 numpy 10 4.923 0.013 4.901 4.917 4.920 4.929 4.954 1.000
(time in wall seconds, less is better)
benchmarks.isoneutral_mixing
============================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.004 0.001 0.003 0.004 0.004 0.004 0.013 1.000
4,096 cupy 1,000 0.013 0.002 0.010 0.011 0.012 0.015 0.026 0.343
16,384 cupy 1,000 0.013 0.002 0.010 0.011 0.012 0.015 0.024 1.273
16,384 numpy 1,000 0.017 0.001 0.015 0.016 0.016 0.017 0.027 1.000
65,536 cupy 100 0.013 0.002 0.011 0.012 0.012 0.015 0.025 5.723
65,536 numpy 100 0.075 0.005 0.068 0.072 0.074 0.077 0.086 1.000
262,144 cupy 100 0.021 0.002 0.018 0.019 0.023 0.023 0.027 13.102
262,144 numpy 10 0.279 0.007 0.272 0.274 0.276 0.286 0.292 1.000
1,048,576 cupy 10 0.071 0.006 0.067 0.068 0.069 0.069 0.083 17.415
1,048,576 numpy 10 1.240 0.020 1.191 1.232 1.250 1.252 1.263 1.000
4,194,304 cupy 10 0.270 0.012 0.259 0.260 0.264 0.280 0.291 18.798
4,194,304 numpy 10 5.071 0.045 4.962 5.048 5.089 5.096 5.124 1.000
(time in wall seconds, less is better)
$ python run.py benchmarks/isoneutral_mixing -b jax -b numpy --device tpu
benchmarks.isoneutral_mixing
============================
Running on TPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 100 0.004 0.003 0.003 0.003 0.004 0.004 0.033 1.603
4,096 numpy 1,000 0.007 0.003 0.005 0.006 0.006 0.006 0.037 1.000
16,384 jax 100 0.005 0.005 0.003 0.004 0.004 0.005 0.041 4.733
16,384 numpy 100 0.024 0.006 0.020 0.021 0.022 0.023 0.065 1.000
65,536 jax 100 0.005 0.002 0.004 0.004 0.005 0.006 0.018 20.059
65,536 numpy 10 0.106 0.009 0.096 0.101 0.103 0.113 0.126 1.000
262,144 jax 10 0.007 0.001 0.006 0.006 0.006 0.007 0.009 68.206
262,144 numpy 10 0.458 0.034 0.364 0.460 0.470 0.473 0.490 1.000
1,048,576 jax 10 0.016 0.002 0.015 0.015 0.015 0.015 0.022 97.621
1,048,576 numpy 10 1.522 0.035 1.471 1.500 1.520 1.540 1.601 1.000
4,194,304 jax 10 0.056 0.009 0.050 0.050 0.051 0.060 0.073 109.384
4,194,304 numpy 10 6.156 0.077 6.071 6.089 6.138 6.195 6.306 1.000
(time in wall seconds, less is better)
This routine combines stencil operations with some linear algebra: a tridiagonal matrix solver, which cannot be vectorized.
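The tridiagonal part is what resists vectorization. In the classic Thomas algorithm (a textbook sketch, not the benchmark's implementation) every step of the forward sweep depends on the result of the previous one:

```python
import numpy as np

def thomas_solve(a, b, c, d):
    """Solve a tridiagonal system with sub-diagonal a, diagonal b,
    super-diagonal c and right-hand side d."""
    n = len(d)
    cp, dp = np.zeros(n), np.zeros(n)
    cp[0] = c[0] / b[0]
    dp[0] = d[0] / b[0]
    for i in range(1, n):            # each step needs the previous result
        denom = b[i] - a[i] * cp[i - 1]
        cp[i] = c[i] / denom
        dp[i] = (d[i] - a[i] * dp[i - 1]) / denom
    x = np.zeros(n)
    x[-1] = dp[-1]
    for i in range(n - 2, -1, -1):   # back substitution, also sequential
        x[i] = dp[i] - cp[i] * x[i + 1]
    return x

# quick sanity check against a dense solve
n = 5
a = np.array([0.0, 1, 1, 1, 1])      # a[0] is unused
b = np.full(n, 4.0)
c = np.array([1.0, 1, 1, 1, 0])      # c[-1] is unused
d = np.arange(1.0, n + 1)
A = np.diag(b) + np.diag(a[1:], -1) + np.diag(c[:-1], 1)
print(np.allclose(thomas_solve(a, b, c, d), np.linalg.solve(A, d)))
```

This sequential dependency chain is why the solver runs as a loop along one axis, and why GPU speedups for this benchmark are more modest than for the elementwise one.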
$ taskset -c 0 python run.py benchmarks/turbulent_kinetic_energy/
benchmarks.turbulent_kinetic_energy
===================================
Running on CPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.000 0.000 0.000 0.000 0.001 0.004 4.918
4,096 numba 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.005 2.312
4,096 pytorch 1,000 0.002 0.001 0.001 0.002 0.002 0.002 0.008 1.227
4,096 numpy 1,000 0.003 0.001 0.002 0.002 0.002 0.003 0.009 1.000
16,384 jax 1,000 0.002 0.000 0.002 0.002 0.002 0.002 0.008 3.708
16,384 numba 1,000 0.004 0.001 0.003 0.003 0.004 0.004 0.009 2.265
16,384 pytorch 1,000 0.005 0.001 0.004 0.004 0.005 0.005 0.009 1.803
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.020 1.000
65,536 jax 100 0.009 0.000 0.008 0.009 0.009 0.009 0.012 4.015
65,536 numba 100 0.013 0.001 0.012 0.013 0.013 0.013 0.019 2.801
65,536 pytorch 100 0.018 0.001 0.016 0.017 0.018 0.019 0.024 2.047
65,536 numpy 100 0.038 0.001 0.035 0.036 0.037 0.038 0.044 1.000
262,144 jax 100 0.040 0.002 0.037 0.039 0.039 0.041 0.047 3.173
262,144 numba 100 0.046 0.003 0.042 0.044 0.045 0.047 0.057 2.745
262,144 pytorch 10 0.064 0.002 0.061 0.062 0.063 0.064 0.068 1.992
262,144 numpy 10 0.127 0.002 0.123 0.125 0.127 0.129 0.130 1.000
1,048,576 numba 10 0.187 0.003 0.183 0.185 0.187 0.189 0.191 3.046
1,048,576 jax 10 0.237 0.003 0.232 0.235 0.236 0.238 0.241 2.408
1,048,576 pytorch 10 0.297 0.005 0.289 0.294 0.296 0.302 0.304 1.918
1,048,576 numpy 10 0.570 0.007 0.559 0.564 0.569 0.577 0.579 1.000
4,194,304 numba 10 0.737 0.010 0.721 0.730 0.739 0.743 0.751 3.447
4,194,304 jax 10 1.212 0.012 1.193 1.204 1.210 1.220 1.232 2.097
4,194,304 pytorch 10 1.404 0.006 1.395 1.400 1.403 1.410 1.415 1.809
4,194,304 numpy 10 2.540 0.014 2.519 2.529 2.545 2.549 2.557 1.000
(time in wall seconds, less is better)
$ for backend in jax pytorch; do python run.py benchmarks/turbulent_kinetic_energy/ --device gpu -b $backend -b numpy; done
benchmarks.turbulent_kinetic_energy
===================================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.004 2.625
4,096 numpy 1,000 0.002 0.000 0.002 0.002 0.002 0.003 0.006 1.000
16,384 jax 1,000 0.001 0.000 0.001 0.001 0.001 0.001 0.003 6.924
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.013 1.000
65,536 jax 100 0.002 0.000 0.002 0.002 0.003 0.003 0.003 15.079
65,536 numpy 100 0.038 0.002 0.035 0.036 0.037 0.038 0.047 1.000
262,144 jax 100 0.010 0.001 0.009 0.010 0.011 0.011 0.011 12.195
262,144 numpy 10 0.128 0.003 0.123 0.127 0.128 0.129 0.132 1.000
1,048,576 jax 10 0.043 0.003 0.040 0.041 0.043 0.046 0.046 12.451
1,048,576 numpy 10 0.540 0.006 0.525 0.538 0.541 0.544 0.545 1.000
4,194,304 jax 10 0.111 0.008 0.099 0.105 0.111 0.119 0.120 20.741
4,194,304 numpy 10 2.309 0.008 2.296 2.303 2.309 2.316 2.320 1.000
(time in wall seconds, less is better)
benchmarks.turbulent_kinetic_energy
===================================
Running on GPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 numpy 1,000 0.003 0.000 0.002 0.002 0.002 0.003 0.006 1.000
4,096 pytorch 1,000 0.003 0.001 0.003 0.003 0.003 0.004 0.007 0.790
16,384 pytorch 1,000 0.004 0.001 0.003 0.003 0.003 0.004 0.008 2.273
16,384 numpy 1,000 0.008 0.001 0.007 0.008 0.008 0.009 0.012 1.000
65,536 pytorch 100 0.005 0.001 0.004 0.004 0.005 0.005 0.008 8.471
65,536 numpy 100 0.039 0.002 0.036 0.038 0.039 0.040 0.045 1.000
262,144 pytorch 100 0.008 0.001 0.007 0.007 0.008 0.008 0.011 16.245
262,144 numpy 10 0.126 0.002 0.123 0.124 0.126 0.128 0.132 1.000
1,048,576 pytorch 10 0.027 0.002 0.025 0.025 0.027 0.027 0.031 20.552
1,048,576 numpy 10 0.549 0.008 0.540 0.545 0.548 0.553 0.567 1.000
4,194,304 pytorch 10 0.108 0.008 0.096 0.101 0.108 0.114 0.123 21.209
4,194,304 numpy 10 2.290 0.008 2.277 2.286 2.289 2.295 2.302 1.000
(time in wall seconds, less is better)
$ python run.py benchmarks/turbulent_kinetic_energy -b jax -b numpy --device tpu
benchmarks.turbulent_kinetic_energy
===================================
Running on TPU
size backend calls mean stdev min 25% median 75% max Δ
------------------------------------------------------------------------------------------------------------------
4,096 jax 100 0.003 0.001 0.002 0.003 0.003 0.003 0.015 1.132
4,096 numpy 1,000 0.004 0.001 0.003 0.003 0.003 0.004 0.035 1.000
16,384 jax 100 0.003 0.001 0.003 0.003 0.003 0.004 0.007 3.322
16,384 numpy 1,000 0.011 0.002 0.010 0.010 0.011 0.011 0.041 1.000
65,536 jax 100 0.004 0.004 0.003 0.003 0.004 0.004 0.031 11.957
65,536 numpy 100 0.050 0.004 0.045 0.048 0.050 0.051 0.065 1.000
262,144 jax 10 0.004 0.000 0.004 0.004 0.004 0.004 0.005 40.486
262,144 numpy 10 0.178 0.005 0.168 0.176 0.177 0.182 0.185 1.000
1,048,576 jax 10 0.008 0.000 0.008 0.008 0.008 0.008 0.009 95.165
1,048,576 numpy 10 0.803 0.040 0.750 0.766 0.797 0.834 0.872 1.000
4,194,304 jax 10 0.022 0.000 0.022 0.022 0.022 0.022 0.023 121.268
4,194,304 numpy 10 2.679 0.349 2.423 2.482 2.511 2.745 3.577 1.000
(time in wall seconds, less is better)