[Jax][Pallas][Mosaic] Implement platform dependent diag, with branch selection driven by constant prop in mosaic lowering. #673

# Cloud TPU CI (presubmit)
#
# This job currently runs as a non-blocking presubmit. It is experimental and is being
# tested until it is stable enough to enable as a blocking presubmit.
name: CI - Cloud TPU (presubmit)
on:
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'
  pull_request:
    branches:
      - main
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
  contents: read
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true
jobs:
  cloud-tpu-test:
    if: github.event.repository.fork == false
    strategy:
      fail-fast: false # don't cancel all jobs on failure
      matrix:
        tpu: [
          {type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
          ]
        python-version: ["3.10"]
    name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"
    env:
      JAXCI_PYTHON: python${{ matrix.python-version }}
      JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}
    runs-on: ${{ matrix.tpu.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    timeout-minutes: 60
    defaults:
      run:
        shell: bash -ex {0}
    steps:
      # https://opensource.google/documentation/reference/github/services#actions
      # mandates using a specific commit for non-Google actions. We use
      # https://github.com/sethvargo/ratchet to pin specific versions.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
      # Checkout XLA at head, if we're building jaxlib at head.
      - name: Checkout XLA at head
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          repository: openxla/xla
          path: xla
      # We need to mark the GitHub workspace as safe as otherwise git commands will fail.
      - name: Mark GitHub workspace as safe
        run: |
          git config --global --add safe.directory "$GITHUB_WORKSPACE"
      - name: Install JAX test requirements
        run: |
          $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
          $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
      - name: Build jaxlib at head with latest XLA
        run: |
          # Build and install jaxlib at head
          $JAXCI_PYTHON build/build.py build --wheels=jaxlib \
            --python_version=${{ matrix.python-version }} \
            --bazel_options=--config=rbe_linux_x86_64 \
            --local_xla_path="$(pwd)/xla" \
            --verbose
          # Install libtpu
          $JAXCI_PYTHON -m pip install --pre libtpu \
            -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
      # Halt for testing
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}
      - name: Install jaxlib wheel and run tests
        run: ./ci/run_pytest_tpu.sh