Performance of halut matmul_online #18

Open

leomem opened this issue May 20, 2024 · 2 comments

Comments


leomem commented May 20, 2024

Hi, I am testing the example python code on an Intel Xeon box. Basically, np.matmul(A_test, B) and hm.matmul_online(A_test) are each executed 1000 times to compare the time difference. I assumed halutmatmul would be much faster, but it turned out that halutmatmul took much longer:
Total time taken to np matmul 1000 times: 0.05877375602722168 seconds
Total time taken to halut matmul 1000 times: 1.6328861713409424 seconds
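
For reference, the timing loop looks roughly like this (a minimal sketch; the matrix shapes are placeholders and the already-trained hm object is stand-in for the setup in the example code, not my exact script):

import time
import numpy as np

# Placeholder inputs; the real script uses the shapes from the example code.
A_test = np.random.random((2048, 64)).astype(np.float32)
B = np.random.random((64, 32)).astype(np.float32)
# hm: an already-trained halutmatmul object exposing matmul_online()
# (construction omitted here; see the example python code in this repo).

def time_repeated(fn, n=1000):
    # Run fn() n times and return the total wall-clock time in seconds.
    start = time.time()
    for _ in range(n):
        fn()
    return time.time() - start

print("Total time taken to np matmul 1000 times:", time_repeated(lambda: np.matmul(A_test, B)))
print("Total time taken to halut matmul 1000 times:", time_repeated(lambda: hm.matmul_online(A_test)))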

Is there anything I am missing? Thanks!

@joennlae
Owner

Hi :-) Thank you for the question.

I get your thinking :-)

np.matmul

np.matmul is a highly optimised routine: each hardware manufacturer provides SIMD-optimised libraries (BLAS implementations) for it, and these are then linked into numpy.

The linking in numpy happens around here:
https://github.com/numpy/numpy/blob/2970735a38b1a1142ab7fd0a14b906611448277e/numpy/_core/src/common/npy_cblas_base.h#L406

Reference to the sgemm documentation of the MKL BLAS library used on your Xeon box:
https://www.intel.com/content/www/us/en/docs/onemkl/developer-reference-c/2023-0/cblas-gemm-001.html
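
As an aside, you can check which BLAS/LAPACK implementation your numpy build is linked against with numpy's own build-info helper:

import numpy as np

# Prints the build configuration, including which BLAS/LAPACK libraries
# (e.g. MKL or OpenBLAS) this numpy build was linked against.
np.show_config()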

halutmatmul

I did some (very simple) optimization in python for halutmatmul:

With numba:

import numpy as np
import numba
from numba import prange

@numba.jit(parallel=True, nopython=True)
def read_luts_opt(
    A_raveled: np.ndarray,
    A_shape: tuple[int, int],
    B_luts: np.ndarray,
    total_result: np.ndarray,
) -> np.ndarray:
    # For each output column i, gather the precomputed LUT entries selected by
    # the encoded rows of A and sum them to get the approximate dot products.
    for i in prange(len(B_luts)):
        read_lut = B_luts[i].ravel()[A_raveled].reshape(A_shape)
        read_lut = read_lut.sum(axis=-1)
        total_result[i] = read_lut
    return total_result

The compilation happens just in time. So run one warm-up call of hm.matmul_online first to trigger the JIT compilation, and then run it 1000 times for the timing. It should already be faster than in your measurement.
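
Roughly (reusing A_test and the trained hm object from your snippet):

import time

# One warm-up call: triggers the numba JIT compilation of read_luts_opt,
# so compilation time is excluded from the measurement below.
hm.matmul_online(A_test)

start = time.time()
for _ in range(1000):
    hm.matmul_online(A_test)
print("Total time taken to halut matmul 1000 times (after warm-up):", time.time() - start)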

But in the end, you will probably not beat the BLAS implementation in terms of speed. That is why we argue for very simple custom hardware support (see paper).

I hope this helps :-)


leomem commented May 21, 2024

Thanks for the information. Very useful.
