Installing JAX with gpu/tpu support using poetry #5516

Closed · pablo2909 opened this issue Apr 28, 2022 · 8 comments · Fixed by #5517

pablo2909 commented Apr 28, 2022

  • I have searched the issues of this repo and believe that this is not a duplicate.
  • I have searched the documentation and believe that my question is not covered.

Issue

Hi everyone,

I am installing JAX with Poetry. Running poetry add jax works fine, but, as expected, it installs the CPU version. To install the GPU/TPU version of JAX, the documentation says to run:

pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

I understand I could run these commands directly in my environment, but I believe packages installed that way would not be handled well by Poetry (they would not be recorded in pyproject.toml or poetry.lock). Is there a more Poetry-native way of installing JAX with GPU/TPU support?

Thank you for the help :)
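For context: pip's -f (--find-links) option points pip at a flat HTML page and treats every linked archive on it as a candidate distribution, which is how a single URL can serve every jaxlib build. The following is a minimal sketch of that idea using only the standard library's html.parser; it is not pip's implementation, and the sample page and example.com URLs are hypothetical.

```python
from html.parser import HTMLParser
from typing import List


class LinkCollector(HTMLParser):
    """Collect the href of every <a> tag on a flat "find-links" style page."""

    def __init__(self) -> None:
        super().__init__()
        self.hrefs: List[str] = []

    def handle_starttag(self, tag, attrs):
        if tag == "a":
            for name, value in attrs:
                if name == "href" and value:
                    self.hrefs.append(value)


def wheel_links(html: str, project: str) -> List[str]:
    """Return links whose filename starts with '<project>-' and ends in '.whl'."""
    collector = LinkCollector()
    collector.feed(html)
    collector.close()
    prefix = project.lower() + "-"
    links = []
    for href in collector.hrefs:
        # Drop any '#sha256=...' fragment, then keep only the last path segment.
        filename = href.split("#", 1)[0].rsplit("/", 1)[-1]
        if filename.lower().startswith(prefix) and filename.endswith(".whl"):
            links.append(href)
    return links
```

Note that every candidate for every project sits on the same page, so a client has to scan the whole listing to find the files it cares about.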

abn (Member) commented Apr 28, 2022

Poetry currently supports only PyPI and PEP 503 "simple API" repositories (e.g. Nexus, GitLab Packages, Artifactory). A project in a similar situation to JAX is PyTorch, which recently started providing a PEP 503 repository, making things easier.

Note that this still has some annoyances; see #4231 (comment).
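By contrast with a flat find-links page, a PEP 503 repository exposes one page per project at <index-url>/<normalized-name>/, so a client only needs to fetch the page for the package it is resolving. A minimal sketch of that lookup, assuming a placeholder index URL; the normalization regex is the one PEP 503 specifies.

```python
import re


def normalize(name: str) -> str:
    """Normalize a project name as specified by PEP 503."""
    return re.sub(r"[-_.]+", "-", name).lower()


def project_page_url(index_url: str, name: str) -> str:
    """Build the per-project page URL on a PEP 503 simple repository."""
    return index_url.rstrip("/") + "/" + normalize(name) + "/"
```

For example, project_page_url("https://example.com/simple/", "JAXlib") gives "https://example.com/simple/jaxlib/".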

Support for "find-links" style sources will probably come in the future. Note, however, that Poetry will have to download every file listed on such a page in order to obtain its hash and inspect its metadata.
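To illustrate the hash part of this: poetry.lock pins each file with a sha256:<hex> entry (see the [metadata.files] section of the lock file posted later in this thread). PEP 503 lets an index advertise a hash in the link's URL fragment (#sha256=...); when no hash is advertised, the file itself has to be downloaded and hashed. The sketch below shows only that decision and assumes a hypothetical download callable supplied by the caller; inspecting metadata such as dependencies would still require the file.

```python
import hashlib
from typing import Callable, Optional
from urllib.parse import urlsplit


def hash_from_fragment(url: str) -> Optional[str]:
    """Return 'sha256:<hex>' if the link advertises a sha256 in its fragment, else None."""
    fragment = urlsplit(url).fragment  # e.g. "sha256=abc123..."
    algo, sep, digest = fragment.partition("=")
    if sep and algo == "sha256" and digest:
        return "sha256:" + digest
    return None


def hash_from_bytes(data: bytes) -> str:
    """Hash downloaded file contents, formatted like a poetry.lock hash entry."""
    return "sha256:" + hashlib.sha256(data).hexdigest()


def lock_hash(url: str, download: Callable[[str], bytes]) -> str:
    """Prefer the advertised hash; otherwise fall back to downloading the file."""
    advertised = hash_from_fragment(url)
    if advertised is not None:
        return advertised
    return hash_from_bytes(download(url))
```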

All that said, in the meantime you can declare URL dependencies that point directly to the wheels you want, ideally combined with platform markers.
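A rough illustration of the URL-dependency-plus-markers idea: each wheel URL is paired with conditions on PEP 508 marker variables (sys_platform, platform_machine, python_version), and the first candidate whose conditions all match the current environment wins. This is a hypothetical sketch, not how Poetry evaluates markers: real markers support comparison and boolean operators, and the example.com URLs are placeholders.

```python
import platform
import sys
from typing import Dict, List, Optional, Tuple

# Hypothetical candidates: PEP 508 marker conditions -> wheel URL (placeholders).
CANDIDATES: List[Tuple[Dict[str, str], str]] = [
    ({"sys_platform": "linux", "platform_machine": "x86_64", "python_version": "3.10"},
     "https://example.com/jaxlib-0.3.7+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl"),
    ({"sys_platform": "linux", "platform_machine": "x86_64", "python_version": "3.9"},
     "https://example.com/jaxlib-0.3.7+cuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl"),
]


def current_environment() -> Dict[str, str]:
    """Collect the values of a few PEP 508 marker variables for this interpreter."""
    return {
        "sys_platform": sys.platform,
        "platform_machine": platform.machine(),
        "python_version": "%d.%d" % sys.version_info[:2],
    }


def select_wheel(env: Dict[str, str],
                 candidates: List[Tuple[Dict[str, str], str]] = CANDIDATES) -> Optional[str]:
    """Return the first URL whose conditions all equal the environment's values."""
    for conditions, url in candidates:
        if all(env.get(key) == value for key, value in conditions.items()):
            return url
    return None
```

Calling select_wheel(current_environment()) returns None on any platform without a matching entry, which mirrors how a marker-guarded dependency is simply skipped there.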

Hope this helps.

abn (Member) commented Apr 29, 2022

Can you please try the fix at #5517?

Using pipx

pipx install --suffix=@5517 'poetry @ git+https://github.com/python-poetry/poetry.git@refs/pull/5517/head'

Using a container (podman | docker)

podman run --rm -i --entrypoint bash docker.io/python:3.10 <<EOF
set -xe
python -m pip install -q git+https://github.com/python-poetry/poetry.git@refs/pull/5517/head
poetry new foobar
pushd foobar
poetry source add jax https://storage.googleapis.com/jax-releases/jax_releases.html
poetry add -vvv --lock --source jax jaxlib
cat poetry.lock
EOF

pablo2909 (Author) commented:

Thank you for writing that!

I'm not sure how pipx works. If I run this command without a container, what would happen to my installed version of Poetry?
Thank you for the help!

abn (Member) commented Apr 29, 2022

The --suffix=@5517 option means that this instance of Poetry will be available as poetry@5517. It won't affect your default Poetry installation, because pipx installs it into its own isolated environment.

pablo2909 (Author) commented:

Here is the pyproject.toml file I have:

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = 
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
jaxlib = {version = "^0.3.7+cuda11.cudnn82", source = "jax"}



[[tool.poetry.source]]
name = "jax"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
default = false
secondary = false

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

and the resulting poetry.lock:

[[package]]
name = "absl-py"
version = "1.0.0"
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
category = "main"
optional = false
python-versions = ">=3.6"

[package.dependencies]
six = "*"

[[package]]
name = "flatbuffers"
version = "2.0"
description = "The FlatBuffers serialization format for Python"
category = "main"
optional = false
python-versions = "*"

[[package]]
name = "jaxlib"
version = "0.3.7+cuda11.cudnn82"
description = "XLA library for JAX"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
absl-py = "*"
flatbuffers = ">=1.12,<3.0"
numpy = ">=1.19"
scipy = "*"

[package.source]
type = "legacy"
url = "https://storage.googleapis.com/jax-releases/jax_releases.html"
reference = "jax"

[[package]]
name = "numpy"
version = "1.22.3"
description = "NumPy is the fundamental package for array computing with Python."
category = "main"
optional = false
python-versions = ">=3.8"

[[package]]
name = "scipy"
version = "1.6.1"
description = "SciPy: Scientific Library for Python"
category = "main"
optional = false
python-versions = ">=3.7"

[package.dependencies]
numpy = ">=1.16.5"

[[package]]
name = "six"
version = "1.16.0"
description = "Python 2 and 3 compatibility utilities"
category = "main"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"

[metadata]
lock-version = "1.1"
python-versions = "^3.10"
content-hash = "9f10c226d6b941232791e19862e24808f4136300ba691a1f8505220b8c36ab57"

[metadata.files]
absl-py = [
    {file = "absl-py-1.0.0.tar.gz", hash = "sha256:ac511215c01ee9ae47b19716599e8ccfa746f2e18de72bdf641b79b22afa27ea"},
    {file = "absl_py-1.0.0-py3-none-any.whl", hash = "sha256:84e6dcdc69c947d0c13e5457d056bd43cade4c2393dce00d684aedea77ddc2a3"},
]
flatbuffers = [
    {file = "flatbuffers-2.0-py2.py3-none-any.whl", hash = "sha256:3751954f0604580d3219ae49a85fafec9d85eec599c0b96226e1bc0b48e57474"},
    {file = "flatbuffers-2.0.tar.gz", hash = "sha256:12158ab0272375eab8db2d663ae97370c33f152b27801fa6024e1d6105fd4dd2"},
]
jaxlib = [
    {file = "jaxlib-0.3.7+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl", hash = "sha256:1d7e540071bad5a76a2ad8a2b6c0dd075adaabe4bab7fb6e116f04ff5425fe1b"},
    {file = "jaxlib-0.3.7+cuda11.cudnn82-cp37-none-manylinux2014_x86_64.whl", hash = "sha256:8ce56ccf18fd79c476910251875e7f0f73417d4ec4912b29b2066d9ff8d82997"},
    {file = "jaxlib-0.3.7+cuda11.cudnn82-cp38-none-manylinux2014_x86_64.whl", hash = "sha256:f6076884c5d1bbf55c2fb153454afb118beeedcd85189793217c82ecb234fc8c"},
    {file = "jaxlib-0.3.7+cuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl", hash = "sha256:16a87c125f0075d62995b18eba449e962b45db010435e6f0a65ee701378fc75f"},
]
numpy = [
    {file = "numpy-1.22.3-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:92bfa69cfbdf7dfc3040978ad09a48091143cffb778ec3b03fa170c494118d75"},
    {file = "numpy-1.22.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8251ed96f38b47b4295b1ae51631de7ffa8260b5b087808ef09a39a9d66c97ab"},
    {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48a3aecd3b997bf452a2dedb11f4e79bc5bfd21a1d4cc760e703c31d57c84b3e"},
    {file = "numpy-1.22.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3bae1a2ed00e90b3ba5f7bd0a7c7999b55d609e0c54ceb2b076a25e345fa9f4"},
    {file = "numpy-1.22.3-cp310-cp310-win32.whl", hash = "sha256:f950f8845b480cffe522913d35567e29dd381b0dc7e4ce6a4a9f9156417d2430"},
    {file = "numpy-1.22.3-cp310-cp310-win_amd64.whl", hash = "sha256:08d9b008d0156c70dc392bb3ab3abb6e7a711383c3247b410b39962263576cd4"},
    {file = "numpy-1.22.3-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:201b4d0552831f7250a08d3b38de0d989d6f6e4658b709a02a73c524ccc6ffce"},
    {file = "numpy-1.22.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f8c1f39caad2c896bc0018f699882b345b2a63708008be29b1f355ebf6f933fe"},
    {file = "numpy-1.22.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:568dfd16224abddafb1cbcce2ff14f522abe037268514dd7e42c6776a1c3f8e5"},
    {file = "numpy-1.22.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ca688e1b9b95d80250bca34b11a05e389b1420d00e87a0d12dc45f131f704a1"},
    {file = "numpy-1.22.3-cp38-cp38-win32.whl", hash = "sha256:e7927a589df200c5e23c57970bafbd0cd322459aa7b1ff73b7c2e84d6e3eae62"},
    {file = "numpy-1.22.3-cp38-cp38-win_amd64.whl", hash = "sha256:07a8c89a04997625236c5ecb7afe35a02af3896c8aa01890a849913a2309c676"},
    {file = "numpy-1.22.3-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:2c10a93606e0b4b95c9b04b77dc349b398fdfbda382d2a39ba5a822f669a0123"},
    {file = "numpy-1.22.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fade0d4f4d292b6f39951b6836d7a3c7ef5b2347f3c420cd9820a1d90d794802"},
    {file = "numpy-1.22.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bfb1bb598e8229c2d5d48db1860bcf4311337864ea3efdbe1171fb0c5da515d"},
    {file = "numpy-1.22.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97098b95aa4e418529099c26558eeb8486e66bd1e53a6b606d684d0c3616b168"},
    {file = "numpy-1.22.3-cp39-cp39-win32.whl", hash = "sha256:fdf3c08bce27132395d3c3ba1503cac12e17282358cb4bddc25cc46b0aca07aa"},
    {file = "numpy-1.22.3-cp39-cp39-win_amd64.whl", hash = "sha256:639b54cdf6aa4f82fe37ebf70401bbb74b8508fddcf4797f9fe59615b8c5813a"},
    {file = "numpy-1.22.3-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c34ea7e9d13a70bf2ab64a2532fe149a9aced424cd05a2c4ba662fd989e3e45f"},
    {file = "numpy-1.22.3.zip", hash = "sha256:dbc7601a3b7472d559dc7b933b18b4b66f9aa7452c120e87dfb33d02008c8a18"},
]
scipy = [
    {file = "scipy-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:a15a1f3fc0abff33e792d6049161b7795909b40b97c6cc2934ed54384017ab76"},
    {file = "scipy-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:e79570979ccdc3d165456dd62041d9556fb9733b86b4b6d818af7a0afc15f092"},
    {file = "scipy-1.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:a423533c55fec61456dedee7b6ee7dce0bb6bfa395424ea374d25afa262be261"},
    {file = "scipy-1.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:33d6b7df40d197bdd3049d64e8e680227151673465e5d85723b3b8f6b15a6ced"},
    {file = "scipy-1.6.1-cp37-cp37m-win32.whl", hash = "sha256:6725e3fbb47da428794f243864f2297462e9ee448297c93ed1dcbc44335feb78"},
    {file = "scipy-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:5fa9c6530b1661f1370bcd332a1e62ca7881785cc0f80c0d559b636567fab63c"},
    {file = "scipy-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bd50daf727f7c195e26f27467c85ce653d41df4358a25b32434a50d8870fc519"},
    {file = "scipy-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:f46dd15335e8a320b0fb4685f58b7471702234cba8bb3442b69a3e1dc329c345"},
    {file = "scipy-1.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:0e5b0ccf63155d90da576edd2768b66fb276446c371b73841e3503be1d63fb5d"},
    {file = "scipy-1.6.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:2481efbb3740977e3c831edfd0bd9867be26387cacf24eb5e366a6a374d3d00d"},
    {file = "scipy-1.6.1-cp38-cp38-win32.whl", hash = "sha256:68cb4c424112cd4be886b4d979c5497fba190714085f46b8ae67a5e4416c32b4"},
    {file = "scipy-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:5f331eeed0297232d2e6eea51b54e8278ed8bb10b099f69c44e2558c090d06bf"},
    {file = "scipy-1.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8a51d33556bf70367452d4d601d1742c0e806cd0194785914daf19775f0e67"},
    {file = "scipy-1.6.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:83bf7c16245c15bc58ee76c5418e46ea1811edcc2e2b03041b804e46084ab627"},
    {file = "scipy-1.6.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:794e768cc5f779736593046c9714e0f3a5940bc6dcc1dba885ad64cbfb28e9f0"},
    {file = "scipy-1.6.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:5da5471aed911fe7e52b86bf9ea32fb55ae93e2f0fac66c32e58897cfb02fa07"},
    {file = "scipy-1.6.1-cp39-cp39-win32.whl", hash = "sha256:8e403a337749ed40af60e537cc4d4c03febddcc56cd26e774c9b1b600a70d3e4"},
    {file = "scipy-1.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:a5193a098ae9f29af283dcf0041f762601faf2e595c0db1da929875b7570353f"},
    {file = "scipy-1.6.1.tar.gz", hash = "sha256:c4fceb864890b6168e79b0e714c585dbe2fd4222768ee90bc1aa0f8218691b11"},
]
six = [
    {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"},
    {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]
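Reading the jaxlib entries above by their wheel tags shows that the lock file only records manylinux2014_x86_64 wheels for CPython 3.7 to 3.10, so it provides no jaxlib artifact for other platforms. A minimal sketch of splitting a wheel filename into the fields defined by PEP 427; it does no validation beyond checking the extension and field count.

```python
from typing import NamedTuple


class WheelTags(NamedTuple):
    name: str
    version: str
    python: str
    abi: str
    platform: str


def parse_wheel_filename(filename: str) -> WheelTags:
    """Split '{name}-{version}(-{build})?-{python}-{abi}-{platform}.whl' (PEP 427)."""
    if not filename.endswith(".whl"):
        raise ValueError("not a wheel: " + filename)
    parts = filename[: -len(".whl")].split("-")
    if len(parts) not in (5, 6):  # 6 when an optional build tag is present
        raise ValueError("unexpected wheel filename: " + filename)
    python, abi, plat = parts[-3:]
    return WheelTags(parts[0], parts[1], python, abi, plat)


# The jaxlib filenames recorded in the lock file above.
JAXLIB_WHEELS = [
    "jaxlib-0.3.7+cuda11.cudnn82-cp310-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.7+cuda11.cudnn82-cp37-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.7+cuda11.cudnn82-cp38-none-manylinux2014_x86_64.whl",
    "jaxlib-0.3.7+cuda11.cudnn82-cp39-none-manylinux2014_x86_64.whl",
]
```

The local version label (+cuda11.cudnn82) contains no hyphen, so it stays inside the version field.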

abn (Member) commented Apr 29, 2022

Thanks for testing it out, @pablo2909. It looks like the change is working as intended; local version selection is a separate problem (see the open issue #4729).
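Some background on local versions: the pyproject.toml above constrains jaxlib with a local version label (+cuda11.cudnn82). PEP 440 says that when the specified version has no local label, candidates' local labels are ignored when matching, and when it does have one, the labels must be equal. A simplified sketch of that rule for exact (==) matching only; it compares public versions as plain strings and skips PEP 440 normalization, so it is not a substitute for a real version library.

```python
from typing import Tuple


def split_local(version: str) -> Tuple[str, str]:
    """Split a PEP 440 version into (public, local); local is '' if absent."""
    public, _, local = version.partition("+")
    return public, local


def matches_exact(candidate: str, specified: str) -> bool:
    """Simplified '==' check following PEP 440's rule for local version labels.

    No normalization is done, so e.g. '0.3.7' and '0.3.7.0' are treated as different.
    """
    cand_public, cand_local = split_local(candidate)
    spec_public, spec_local = split_local(specified)
    if cand_public != spec_public:
        return False
    # A specifier without a local label ignores the candidate's label entirely.
    return spec_local == "" or cand_local == spec_local
```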

pablo2909 (Author) commented:

Nice, thank you for your help. Looking forward to this feature being merged!

github-actions bot commented Mar 2, 2024

This issue has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

The github-actions bot locked this issue as resolved and limited conversation to collaborators on Mar 2, 2024.