[Compiler Optimization]: A pattern containing the reciprocal of the sqrt is not optimized to rsqrt automatically #507

Closed
winnylyc opened this issue Jan 18, 2024 · 2 comments · Fixed by #508
@winnylyc
Contributor

Issue Type

Performance

Modules Involved

SPU compiler

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

spu 0.7.0b0

OS Platform and Distribution

Linux Ubuntu 22.04

Python Version

3.10

Compiler Version

No response

Current Behavior?

Hello, sorry for disturbing you.

While implementing a normalizer in SPU, I ran into a problem: a pattern containing the reciprocal of a sqrt is not automatically optimized to rsqrt. I would like to ask whether this kind of pattern can be supported in the future, and whether such an automatic optimization could have side effects.

Standalone code to reproduce the issue

import jax
import jax.numpy as jnp

# The original code containing the reciprocal of the sqrt
X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]])
norms = jnp.einsum("ij,ij->i", X, X)
X / jnp.sqrt(norms)[:, jnp.newaxis]

# The manually optimized code with the same result but better speed
X = jnp.array([[4, 1, 2, 2], [1, 3, 9, 3], [5, 7, 5, 1]])
norms = jnp.einsum("ij,ij->i", X, X)
norms = norms.astype(jnp.float32)
X * jax.lax.rsqrt(norms)[:, jnp.newaxis]

Relevant log output

# Performance measured with the following emulator
# emulator = emulation.Emulator(
#             emulation.CLUSTER_ABY3_3PC,
#             emulation.Mode.MULTIPROCESS,
#             bandwidth=300,
#             latency=20,
#         )
# The profile for original code containing the reciprocal of the sqrt
[2024-01-18 05:03:44.755] [info] [api.cc:191] HLO profiling: total time 0.025788818
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.broadcast, executed 1 times, duration 4.766e-06s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.convert, executed 2 times, duration 0.000128667s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.divide, executed 1 times, duration 0.0106705s, send bytes 8256
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.dot_general, executed 1 times, duration 0.00157601s, send bytes 24
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.free, executed 7 times, duration 5.65e-06s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.reshape, executed 1 times, duration 1.2534e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.sqrt, executed 1 times, duration 0.013370251s, send bytes 2124
[2024-01-18 05:03:44.755] [info] [api.cc:194] - pphlo.transpose, executed 1 times, duration 2.044e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:191] HAL profiling: total time 0.025575716
[2024-01-18 05:03:44.755] [info] [api.cc:194] - f_div, executed 1 times, duration 0.01066082s, send bytes 8256
[2024-01-18 05:03:44.755] [info] [api.cc:194] - f_sqrt, executed 1 times, duration 0.013347743s, send bytes 2124
[2024-01-18 05:03:44.755] [info] [api.cc:194] - i_mmul, executed 3 times, duration 0.001476084s, send bytes 24
[2024-01-18 05:03:44.755] [info] [api.cc:194] - int2fxp, executed 2 times, duration 9.1069e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:191] MPC profiling: total time 0.024049148
[2024-01-18 05:03:44.755] [info] [api.cc:194] - a2b, executed 2 times, duration 0.002759383s, send bytes 1680
[2024-01-18 05:03:44.755] [info] [api.cc:194] - add_aa, executed 7 times, duration 3.3803e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - add_ap, executed 15 times, duration 7.9238e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - add_pp, executed 2 times, duration 1.198e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - and_bb, executed 12 times, duration 0.001702799s, send bytes 660
[2024-01-18 05:03:44.755] [info] [api.cc:194] - and_bp, executed 18 times, duration 6.0062e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - b2a, executed 6 times, duration 0.00797263s, send bytes 3720
[2024-01-18 05:03:44.755] [info] [api.cc:194] - bitrev_b, executed 3 times, duration 2.998e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - lshift_a, executed 2 times, duration 4.4846e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - lshift_b, executed 6 times, duration 1.6841e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - make_p, executed 26 times, duration 8.4075e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - mmul_aa, executed 3 times, duration 0.001446647s, send bytes 24
[2024-01-18 05:03:44.755] [info] [api.cc:194] - msb_a2b, executed 1 times, duration 0.001168064s, send bytes 432
[2024-01-18 05:03:44.755] [info] [api.cc:194] - mul_aa, executed 19 times, duration 0.003576891s, send bytes 1104
[2024-01-18 05:03:44.755] [info] [api.cc:194] - mul_ap, executed 7 times, duration 3.9717e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - not_a, executed 7 times, duration 4.5822e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - not_p, executed 1 times, duration 7.786e-06s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - rshift_b, executed 25 times, duration 7.0549e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - trunc_a, executed 18 times, duration 0.004743633s, send bytes 2784
[2024-01-18 05:03:44.755] [info] [api.cc:194] - xor_bb, executed 42 times, duration 0.000132476s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:194] - xor_bp, executed 2 times, duration 2.1926e-05s, send bytes 0
[2024-01-18 05:03:44.755] [info] [api.cc:204] Link details: total send bytes 10404, send actions 132

# The profile for manually optimized code with the same result but better speed
[2024-01-18 04:43:48.192] [info] [api.cc:191] HLO profiling: total time 0.015152912000000001
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.broadcast, executed 1 times, duration 2.0069e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.convert, executed 1 times, duration 1.403e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.dot_general, executed 1 times, duration 0.002450025s, send bytes 24
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.free, executed 6 times, duration 4.104e-06s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.multiply, executed 1 times, duration 0.000124317s, send bytes 96
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.reshape, executed 1 times, duration 2.6426e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.rsqrt, executed 1 times, duration 0.012495452s, send bytes 1716
[2024-01-18 04:43:48.192] [info] [api.cc:194] - pphlo.transpose, executed 1 times, duration 1.8489e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:191] HAL profiling: total time 0.015002936
[2024-01-18 04:43:48.192] [info] [api.cc:194] - f_rsqrt, executed 1 times, duration 0.012484559s, send bytes 1716
[2024-01-18 04:43:48.192] [info] [api.cc:194] - i_mmul, executed 3 times, duration 0.002388416s, send bytes 24
[2024-01-18 04:43:48.192] [info] [api.cc:194] - int2fxp, executed 1 times, duration 9.48e-06s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - mixed_mul, executed 1 times, duration 0.000120481s, send bytes 96
[2024-01-18 04:43:48.192] [info] [api.cc:191] MPC profiling: total time 0.014330171999999999
[2024-01-18 04:43:48.192] [info] [api.cc:194] - a2b, executed 1 times, duration 0.00212471s, send bytes 336
[2024-01-18 04:43:48.192] [info] [api.cc:194] - add_aa, executed 3 times, duration 5.268e-06s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - add_ap, executed 2 times, duration 1.2276e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - add_pp, executed 2 times, duration 1.1417e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - and_bb, executed 6 times, duration 0.001191588s, send bytes 132
[2024-01-18 04:43:48.192] [info] [api.cc:194] - and_bp, executed 18 times, duration 7.8069e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - b2a, executed 3 times, duration 0.004090321s, send bytes 768
[2024-01-18 04:43:48.192] [info] [api.cc:194] - bitrev_b, executed 2 times, duration 9.278e-06s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - lshift_a, executed 1 times, duration 4.684e-06s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - lshift_b, executed 6 times, duration 1.8486e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - make_p, executed 16 times, duration 4.2967e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - mmul_aa, executed 3 times, duration 0.002357849s, send bytes 24
[2024-01-18 04:43:48.192] [info] [api.cc:194] - mul_aa, executed 7 times, duration 0.001272991s, send bytes 240
[2024-01-18 04:43:48.192] [info] [api.cc:194] - mul_ap, executed 5 times, duration 1.5179e-05s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - not_p, executed 1 times, duration 7.508e-06s, send bytes 0
[2024-01-18 04:43:48.192] [info] [api.cc:194] - rshift_b, executed 18 times, duration 4.3379e-05s, send bytes 0
[2024-01-18 04:43:48.193] [info] [api.cc:194] - trunc_a, executed 6 times, duration 0.002954356s, send bytes 336
[2024-01-18 04:43:48.193] [info] [api.cc:194] - xor_bb, executed 29 times, duration 7.3501e-05s, send bytes 0
[2024-01-18 04:43:48.193] [info] [api.cc:194] - xor_bp, executed 1 times, duration 1.6345e-05s, send bytes 0
[2024-01-18 04:43:48.193] [info] [api.cc:204] Link details: total send bytes 1836, send actions 56
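
To put the two profiles side by side, here is a small sketch in plain Python that compares the totals reported above. The dictionary keys and the `ratio` helper are illustrative names only, not part of SPU; the numbers are copied from the "HLO profiling" and "Link details" lines of each run.

```python
# Totals copied from the two profiles above (names are illustrative only).
baseline = {"hlo_seconds": 0.025788818, "send_bytes": 10404, "send_actions": 132}
optimized = {"hlo_seconds": 0.015152912, "send_bytes": 1836, "send_actions": 56}


def ratio(key):
    """Return how many times larger the baseline value is than the optimized one."""
    return baseline[key] / optimized[key]


summary = {key: round(ratio(key), 2) for key in baseline}
for key, value in summary.items():
    print(f"{key}: {value}x")
```

On these single runs the rsqrt version is roughly 1.7x faster in total HLO time and sends between 5x and 6x fewer bytes. These are one-off measurements in an emulated network, so the exact ratios may vary between runs.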
@anakinxc anakinxc self-assigned this Jan 18, 2024
@anakinxc
Contributor

Hi @winnylyc

Thanks for reporting this. Yes, this can be optimized by our compiler stack. We will try to add this in the next release.

Best
~Yancheng
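
As a rough illustration of the kind of rewrite discussed here, the sketch below replaces `x / sqrt(y)` with `x * rsqrt(y)` on a toy expression tree. It is not SPU's actual compiler pass: the `Var`, `Sqrt`, `Rsqrt`, `Div`, and `Mul` classes and the `rewrite` and `evaluate` functions are hypothetical names made up for this example, and plain floats stand in for secret-shared values.

```python
import math
from dataclasses import dataclass


# Hypothetical toy IR, not SPU's real operation set.
@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Sqrt:
    arg: object


@dataclass(frozen=True)
class Rsqrt:
    arg: object


@dataclass(frozen=True)
class Div:
    lhs: object
    rhs: object


@dataclass(frozen=True)
class Mul:
    lhs: object
    rhs: object


def rewrite(expr):
    """Rewrite Div(x, Sqrt(y)) into Mul(x, Rsqrt(y)), working bottom-up."""
    if isinstance(expr, Var):
        return expr
    if isinstance(expr, (Sqrt, Rsqrt)):
        return type(expr)(rewrite(expr.arg))
    lhs, rhs = rewrite(expr.lhs), rewrite(expr.rhs)
    if isinstance(expr, Div) and isinstance(rhs, Sqrt):
        return Mul(lhs, Rsqrt(rhs.arg))
    return type(expr)(lhs, rhs)


def evaluate(expr, env):
    """Evaluate the toy expression on plain floats."""
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, Sqrt):
        return math.sqrt(evaluate(expr.arg, env))
    if isinstance(expr, Rsqrt):
        return 1.0 / math.sqrt(evaluate(expr.arg, env))
    lhs, rhs = evaluate(expr.lhs, env), evaluate(expr.rhs, env)
    return lhs / rhs if isinstance(expr, Div) else lhs * rhs


original = Div(Var("x"), Sqrt(Var("n")))
optimized = rewrite(original)
env = {"x": 4.0, "n": 25.0}  # first element and squared norm of the first row of X
print(optimized)
print(evaluate(original, env), evaluate(optimized, env))
```

On the question of side effects: in a real compiler the `sqrt` result may have other users, in which case the rewrite would add an `rsqrt` without removing the `sqrt`, so a pass would likely need to check for that. With fixed-point approximations the two forms may also differ slightly in the low bits.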

@anakinxc anakinxc changed the title [Bug]: A pattern containing the reciprocal of the sqrt is not optimized to rsqrt automatically [Compiler Optimization]: A pattern containing the reciprocal of the sqrt is not optimized to rsqrt automatically Jan 18, 2024
@winnylyc
Contributor Author

Thanks for your response!

@winnylyc winnylyc reopened this Jan 18, 2024
@anakinxc anakinxc mentioned this issue Jan 19, 2024
anakinxc added a commit that referenced this issue Jan 19, 2024
# Pull Request

## What problem does this PR solve?

Issue Number: Fixes #507, partially resolves #387

## Possible side effects?

- Performance:

- Backward compatibility: