-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ship Transformer Engine in the JAX container #132
Conversation
# Transformer Engine installation dependencies | ||
pip install --no-cache-dir pybind11 ninja packaging | ||
# Install JAX + Transformer Engine | ||
NVTE_FRAMEWORK=jax pip --disable-pip-version-check --no-cache-dir install -e \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering what you think about setting ENV NVTE_FRAMEWORK=jax
. Since we must set that environment variable to correctly install TE, it seems more friendly for re-installs. For example, when installing t5x next, it'll look something like pip install -e /opt/transformer-engine -e /opt/t5x
and if NVTE_FRAMEWORK
is set in the base container, we can save a few characters when installing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The env variable isn't needed anymore if jax is already installed.
By default, the setup.py file try to import all the fw and will build for those installed.
…mber limit Will try to land #132 ASAP as a substitution.
Completed in #371. |
After merging this, we will no longer build new
ghcr.io/nvidia/jax-te
images. Instead,ghcr.io/nvidia/jax
will always ship with TE.