Vmap isn't working when using JAX-based MPSCircuits #214

Open
Muzhou-Ma opened this issue May 24, 2024 · 3 comments
Labels
bug Something isn't working

Comments

@Muzhou-Ma

Issue Description

While testing #213, I found that vmap does not work with a JAX-backed MPSCircuit: the program is not parallelized and uses only one CPU core. The bug seems to be caused by MPSCircuit itself, since with an ordinary JAX-backed circuit everything works fine.
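For comparison, the following is a minimal pure-JAX sketch of the kind of step an MPS simulator performs for a two-qubit gate: merge two site tensors, apply the gate, and split them again with an SVD. The function `two_site_step` and the RY + CNOT gate are hypothetical illustrations, not TensorCircuit's MPSCircuit implementation. Because every shape is fixed at trace time, `jax.vmap` can batch the step over a vector of gate angles.

```python
import jax
import jax.numpy as jnp


def two_site_step(theta):
    """Hypothetical MPS-style update: RY(theta) on site 0, CNOT(0, 1), then SVD."""
    # Site tensors for |00>, shape (left_bond, physical, right_bond)
    a0 = jnp.zeros((1, 2, 1)).at[0, 0, 0].set(1.0)
    a1 = jnp.zeros((1, 2, 1)).at[0, 0, 0].set(1.0)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    ry = jnp.array([[c, -s], [s, c]])
    cnot = jnp.array([[1.0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]])
    # Gate as gate[out0, out1, in0, in1]
    gate = (cnot @ jnp.kron(ry, jnp.eye(2))).reshape(2, 2, 2, 2)
    # Merge the two sites over their shared bond, then act on the physical legs
    psi = jnp.einsum("lam,mbr->labr", a0, a1)
    psi = jnp.einsum("cdab,labr->lcdr", gate, psi)
    # Split back into two sites with an SVD; return the singular values
    mat = psi.reshape(2, 2)
    _, sv, _ = jnp.linalg.svd(mat, full_matrices=False)
    return sv


# Batch the whole step over two angles and compile it
thetas = jnp.array([0.0, jnp.pi / 2])
batched = jax.jit(jax.vmap(two_site_step))(thetas)
print(batched.shape)  # (2, 2)
```

For theta = 0 the state stays a product state (singular values 1 and 0); for theta = pi/2 it becomes a Bell state (both singular values 1/sqrt(2)).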

Muzhou-Ma added the bug (Something isn't working) label on May 24, 2024
@refraction-ray
Contributor

refraction-ray commented May 25, 2024

Can be reproduced. It may be due to the same issue as with QR and SVD; these operations might not support vmap.

Update: no, JAX can vmap QR and SVD, so the cause of the vmap failure in MPSCircuit requires further investigation.
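A quick check of the claim that JAX can vmap QR and SVD: the sketch below maps `jnp.linalg.qr` and `jnp.linalg.svd` over a batch of random matrices with `jax.vmap` and verifies that each decomposition reconstructs its input. The batch size and matrix shape are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# A batch of 8 random 4x3 matrices
mats = jax.random.normal(key, (8, 4, 3))

# Map the per-matrix decompositions over the leading batch axis
batched_qr = jax.vmap(jnp.linalg.qr)
batched_svd = jax.vmap(lambda m: jnp.linalg.svd(m, full_matrices=False))

q, r = batched_qr(mats)      # q: (8, 4, 3), r: (8, 3, 3)
u, s, vh = batched_svd(mats)  # u: (8, 4, 3), s: (8, 3), vh: (8, 3, 3)

# Both decompositions should reconstruct every matrix in the batch
qr_err = jnp.max(jnp.abs(q @ r - mats))
svd_err = jnp.max(jnp.abs((u * s[:, None, :]) @ vh - mats))
print(float(qr_err), float(svd_err))
```

Both errors should be at the level of float32 round-off, which is consistent with QR and SVD themselves not being the obstacle to batching.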

@refraction-ray
Contributor

With the tf backend, vmap works, but CPU utilization is very low: only around 150% for my test example.

@Muzhou-Ma
Author

Muzhou-Ma commented May 25, 2024

> tf backend vmap is ok but with very low CPU utilization, only around 150% for my test example

The tf backend seems to emit a warning about the QR decomposition; perhaps this causes the low CPU utilization.
