Commit 2ddaa95

Parent: abb6f97
Author: Yu-Hang Tang <Tang.Maxin@gmail.com> (1698050497 +0000)
Committer: Terry Kong <terryk@nvidia.com> (1701417045 -0800)

pip-compile changes

Updated t5-large perf (#342)

Update Pax README and sub file (#345)

- Adds FP8 documentation
- Updates perf table
- Makes some other minor improvements for readability

Adds CUDA_MODULE_LOADING=EAGER to core jax container env vars (#329)
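
As background, `CUDA_MODULE_LOADING` is a standard CUDA runtime variable; a minimal sketch of the setting outside the container (the Dockerfile bakes it in with `ENV` instead):

```bash
# EAGER loads all CUDA kernels up front instead of lazily on first launch,
# avoiding module-load stalls in the middle of a training step.
export CUDA_MODULE_LOADING=EAGER
python -c 'import jax; print(jax.devices())'  # any JAX process inherits this
```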

Re-enable NVLS in nightly containers (#331)

NVLS was disabled due to a known issue in NCCL 2.17 that caused
intermittent hangs. The issue has been resolved in NCCL 2.18, so we are
safe to re-enable NVLS.
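
As context, NVLS is gated by a single NCCL environment variable; a sketch of the toggle in question (see the [NCCL docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)):

```bash
# The NCCL 2.17 workaround pinned NVLS off in the container environment:
export NCCL_NVLS_ENABLE=0   # disable NVLink SHARP collectives
# With NCCL 2.18 the hang is fixed, so the pin is dropped; leaving the
# variable unset (or =1) lets NCCL use NVLS where the hardware supports it.
unset NCCL_NVLS_ENABLE
```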

---------

Co-authored-by: Terry Kong <terryk@nvidia.com>

Update Pax TE patch to point to rebased branch (#348)

Loosens t5x loss tests relative tolerances (#343)

Relaxing the relative tolerance on the loss tests, since the previous
tolerance was leading to too many false positives. For reference, the
deviation in loss for the t5 model can sometimes be up to 15% at the start
of training with real data.

Adds rosetta-t5x TE + no-TE tests that enable the correct configs for testing (#332)

- [ ] Add capability to retroactively test with newer test-t5x.sh like
in
[t5x-wget-test](https://github.com/NVIDIA/JAX-Toolbox/tree/t5x-wget-test)
- [ ] Sets `ENABLE_TE=1` in Dockerfile.t5x, which matches the previous
logic where TE was always enabled in rosetta-t5x

Fix markdown hyperlink for jax package on front-page README (#319)

Adds a --seed option to test-t5x.sh to ensure determinism (#344)

To ensure that the test results for a particular container are
reproducible between runs, this change introduces a seed argument that
sets the JAX seed and the dataset seed to 42. The seed remains
configurable, but now there shouldn't be any variance given the same container.

- Also fixes an omission where `--steps-per-epoch` was missing from the
usage doc of this script; see the sketch below
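
A hedged sketch of the flag plumbing (the real `test-t5x.sh` argument parsing and the exact gin override names may differ; these are illustrative):

```bash
SEED=42  # default: deterministic runs within a given container

while [[ $# -gt 0 ]]; do
  case "$1" in
    --seed) SEED="$2"; shift 2 ;;
    *)      EXTRA_ARGS+=("$1"); shift ;;
  esac
done

# The same value seeds both the model RNG and the dataset pipeline.
python -m t5x.train \
  --gin.train.random_seed="${SEED}" \
  --gin.utils.DatasetConfig.seed="${SEED}" \
  "${EXTRA_ARGS[@]}"
```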

Co-authored-by: NVIDIA <jax@nvidia.com>
Co-authored-by: Yu-Hang "Maxin" Tang <Tang.Maxin@gmail.com>

Dynamic workflow run names (#356)

This change introduces the dynamic [`run-name`
field](https://github.blog/changelog/2022-09-26-github-actions-dynamic-names-for-workflow-runs/).

It's currently difficult on mobile to find the "workflow_run" that
corresponds to a particular date, so hopefully this helps identify which
builds were nightly and which were manually triggered.

I couldn't find a good way to dynamically look up the `name` field, so
for now I copied all of the names. I also wasn't able to find a "created_at"
for the scheduled workflows, so those don't have timestamps for now.

__Assumptions__:
* "workflow_run" == nightly since "scheduled" events only happen on
`main` and `workflow_run` are only run for concrete workflows and not
reusable workflows

- [x] Test the workflow_run codepath
- [x] Test the scheduled codepath

![image](https://github.com/NVIDIA/JAX-Toolbox/assets/7576060/4b916452-334a-4a73-9220-9fbadc70462f)

Fix randomly failing tests for backend_independent on V100 (#351)

Fixes random failures in the backend-independent section of the JAX unit
tests:
```
Cannot find a free accelerator to run the test  on, exiting with failure
```

Changes: limit the number of concurrent test jobs even for
backend-independent tests, which do create GPU contexts.

As a clarification, `--jobs` and `--local_test_jobs` do not make a
difference for our particular CI pipeline, since JAX is built in a
separate CI job anyway.

References (From Reed Wanderman-Milne @ Google):

> 1. In particular, you have to set NB_GPUS, JOBS_PER_ACC, and J
correctly or you can get that error (I recently got the same error by
not setting those correctly)
> 2. (also I think --jobs should be --local_test_jobs in that code
block, no reason to restrict the number of jobs compiling JAX)
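
A sketch of the cap this implies, using the variable names from the quote (the runner shape and target pattern are assumptions):

```bash
NB_GPUS=8        # accelerators visible on the V100 runner (assumption)
JOBS_PER_ACC=4   # concurrent tests allowed per accelerator (assumption)
J=$((NB_GPUS * JOBS_PER_ACC))

# Throttle test execution only; JAX itself is compiled in a separate CI job,
# so there is no need to also restrict build parallelism with --jobs.
bazel test --local_test_jobs="${J}" //tests/...   # target pattern illustrative
```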

Propagate error code in ViT tests (#357)
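
(The PR body isn't captured here; the usual culprit is a pipe or wrapper swallowing the test's exit status. A generic bash sketch of the fix, with an illustrative test path:)

```bash
set -o pipefail                                 # don't let `tee` mask a failing test
python -m pytest tests/vit | tee vit-test.log   # path illustrative
exit $?                                         # hand the real status back to CI
```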

Merges rosetta unit tests and removes the marker that spun up another matrix job (#360)

This should simplify the rosetta tests and save some time, since
previously a separate matrix job was started for a single test

Propagate build failures (#363)

Always run the `publish-build` step, regardless of whether the rosetta
pax/t5x build was attempted. This ensures that badges correctly reflect
build failures due to dependent builds failing.

Patch for JAX core container (ARM64) (#367)

Add patch to XLA so that the JAX core container can be built for ARM64

Update the doc for USE_FP8 (#349)

This PR provides guidance on how to use the new configuration option,
`USE_FP8`, to enable native FP8 support on Hopper GPUs.
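
A hedged sketch of what enabling it looks like from the launch side, assuming the option is consumed as an environment variable (the launcher script here is hypothetical; the real wiring is described in the Pax README):

```bash
# FP8 execution requires Hopper (H100) GPUs.
export USE_FP8=1                  # option name from this PR
bash run_pax_training.sh ...      # hypothetical launcher that reads USE_FP8
```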

Update the native-fp8 guide with cudnn layer norm (#368)

This PR updates the guide to include the new flag to enable the cudnn
layer norm.

cc. @ashors1 @terrykong @nouiz

Add WAR for XLA NCCL bug causing OOMs (#362)

A stopgap for #346

fix TE multi-device test

fix lzma build issue

edit TE test name

fix TE arm64 test install error

remove --install option from get-source.sh

fix TE arm64 test install error

disable sandbox

i'm jet-lagged

use Pax image for TE testing

Fix job dependency

yhtang authored and terrykong committed on Dec 8, 2023
1 parent fefb8de · commit 2ddaa95
Showing 5 changed files with 45 additions and 9 deletions.
43 changes: 37 additions & 6 deletions .github/container/Dockerfile.jax
```diff
@@ -37,6 +37,17 @@ RUN --mount=type=ssh \
     --mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \
     git clone "${REPO_XLA}" "${SRC_PATH_XLA}" && cd "${SRC_PATH_XLA}" && git checkout ${REF_XLA}
 
+# TODO: This is a WAR to NCCL errors we observe in TOT. Should be removed when no longer needed
+RUN <<EOF bash -ex
+cd ${SRC_PATH_XLA}
+
+git config user.name "JAX Toolbox"
+git config user.email "jax@nvidia.com"
+git remote add -f ashors1 https://github.com/ashors1/xla
+git cherry-pick --allow-empty $(git merge-base ashors/main ashors1/revert-84222)..ashors1/revert-84222
+git remote remove ashors1
+EOF
+
 ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
 ADD xla-arm64-neon.patch /opt
 RUN build-jax.sh \
@@ -68,15 +79,35 @@ COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
 COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
 ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
 
-COPY --from=jax-builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
-COPY --from=jax-builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
+COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
+COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
 ADD build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/
 
 RUN mkdir -p /opt/pip-tools.d
-RUN <<EOF
-echo "-e ${SRC_PATH_JAX}" >> /opt/pip-tools.d/requirements-jax.in
-echo "$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/requirements-jax.in
-echo "flax" >> /opt/pip-tools.d/requirements-jax.in
+RUN <<"EOF" bash -ex
+echo "-e file://${SRC_PATH_JAX}" >> /opt/pip-tools.d/manifest.jax
+echo "jaxlib @ file://$(ls ${SRC_PATH_JAX}/dist/*.whl)" >> /opt/pip-tools.d/manifest.jax
 EOF
 
+## Flax
+ARG REPO_FLAX
+ARG REF_FLAX
+ARG SRC_PATH_FLAX
+RUN get-source.sh -f ${REPO_FLAX} -r ${REF_FLAX} -d ${SRC_PATH_FLAX} -m /opt/pip-tools.d/manifest.flax
+
+## Transformer engine: check out source and build wheel
+ARG REPO_TE
+ARG REF_TE
+ARG SRC_PATH_TE
+ENV NVTE_FRAMEWORK=jax
+ENV SRC_PATH_TE=${SRC_PATH_TE}
+RUN <<"EOF" bash -ex
+set -o pipefail
+pip install ninja && rm -rf ~/.cache/pip
+get-source.sh -f ${REPO_TE} -r ${REF_TE} -d ${SRC_PATH_TE}
+pushd ${SRC_PATH_TE}
+python setup.py bdist_wheel && rm -rf build
+echo "transformer-engine @ file://$(ls ${SRC_PATH_TE}/dist/*.whl)" >> /opt/pip-tools.d/manifest.te
+EOF
 
 # TODO: properly configure entrypoint
```
8 changes: 7 additions & 1 deletion .github/workflows/_sandbox.yaml
```diff
@@ -1,7 +1,13 @@
 name: "~Sandbox"
 
 on:
-  workflow_dispatch:
+  # workflow_dispatch:
+  # push:
+
+permissions:
+  contents: read  # to fetch code
+  actions: write  # to cancel previous workflows
+  packages: write # to upload container
 
 jobs:
   sandbox:
```
1 change: 1 addition & 0 deletions .github/workflows/nightly-t5x-build.yaml
```diff
@@ -71,6 +71,7 @@ jobs:
     runs-on: ubuntu-22.04
     outputs:
       DOCKER_TAG_MEALKIT: ''
+      DOCKER_TAG_FINAL: ''
     steps:
       - name: Generate placeholder warning
         shell: bash -x -e {0}
```
1 change: 0 additions & 1 deletion rosetta/rosetta/projects/t5x/README.md
```diff
@@ -197,7 +197,6 @@ t5x/contrib/gpu/scripts_gpu/singlenode_ft_frompile.sh \
 
 # Known Issues
 * There is a known sporadic NCCL crash that happens when using the T5x container at node counts greater than or equal to 32 nodes. We will fix this in the next release. The issue is tracked [here](https://github.com/NVIDIA/JAX-Toolbox/issues/194).
-* The T5x nightlies disable `NCCL_NVLS_ENABLE=0` ([doc](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature.
 
 # Changelog
 - Added Transformer Engine + FP8 support
```
1 change: 0 additions & 1 deletion rosetta/rosetta/projects/vit/README.md
```diff
@@ -157,5 +157,4 @@ Pre-training was performed on 1 node with a global batch size of 4096. Models we
 
 ## Known Issues
 1. By default, gradient accumulation (GA) sums loss across the microbatches. As a result, loss is scaled up when using gradient accumulation, and training with GA only works when using a scale-invariant optimizer such as Adam or Adafactor. ViT fine-tuning is performed using SGD; thus, GA should not be used when fine-tuning.
-2. The nightlies disable `NCCL_NVLS_ENABLE=0` ([doc](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature.
 
```
