Add Dockerfile and entrypoint script for the jetstream-pytorch-server image
vivianrwu committed Dec 10, 2024
1 parent a0449f7 commit a1b9cd9
Showing 4 changed files with 89 additions and 1 deletion.
52 changes: 52 additions & 0 deletions docker/jetstream-pytorch-server/Dockerfile
@@ -0,0 +1,52 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Ubuntu:22.04
# Use Ubuntu 22.04 from Docker Hub.
# https://hub.docker.com/_/ubuntu/tags?page=1&name=22.04
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTORCH_JETSTREAM_VERSION=main

RUN apt-get -y update && apt-get install -y --no-install-recommends \
    ca-certificates \
    git \
    python3.10 \
    python3-pip

RUN python3 -m pip install --upgrade pip

RUN update-alternatives --install \
    /usr/bin/python3 python3 /usr/bin/python3.10 1


RUN git clone https://github.com/AI-Hypercomputer/JetStream.git
RUN git clone https://github.com/AI-Hypercomputer/jetstream-pytorch.git && \
    cd /jetstream-pytorch && \
    git checkout ${PYTORCH_JETSTREAM_VERSION} && \
    bash install_everything.sh

RUN pip install -U "jax[tpu]==0.4.35" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

RUN cd /JetStream && \
    pip install -e .

RUN pip install "huggingface_hub[cli]"

COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/

RUN chmod +x /usr/bin/jetstream_pytorch_server_entrypoint.sh

ENTRYPOINT ["/usr/bin/jetstream_pytorch_server_entrypoint.sh"]
14 changes: 14 additions & 0 deletions docker/jetstream-pytorch-server/README.md
@@ -0,0 +1,14 @@
## Build and upload JetStream PyTorch Server image

These instructions build the JetStream PyTorch Server image. Its entrypoint script starts the [JetStream](https://github.com/AI-Hypercomputer/JetStream) inference server with the JetStream-PyTorch framework.

```
docker build -t jetstream-pytorch-server .
docker tag jetstream-pytorch-server us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest
docker push us-docker.pkg.dev/${PROJECT_ID}/jetstream/jetstream-pytorch-server:latest
```

To build the image from a different version of JetStream PyTorch, change the `PYTORCH_JETSTREAM_VERSION` environment variable in the Dockerfile:
```
ENV PYTORCH_JETSTREAM_VERSION=<your desired commit hash, release, or tag>
```
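
The tag and push commands above expand `${PROJECT_ID}` directly, so an unset variable silently yields a malformed path (`us-docker.pkg.dev//jetstream/...`). A minimal sketch of a guard, assuming a hypothetical helper named `image_uri` that is not part of this commit:

```shell
# Hypothetical helper (not in the repository): compose the Artifact Registry
# path used in the README and refuse to continue when the project ID is empty.
image_uri() {
  uri_project_id="$1"
  uri_tag="${2:-latest}"   # default to the "latest" tag used in the README
  if [ -z "${uri_project_id}" ]; then
    echo "PROJECT_ID must be set" >&2
    return 1
  fi
  printf 'us-docker.pkg.dev/%s/jetstream/jetstream-pytorch-server:%s\n' \
    "${uri_project_id}" "${uri_tag}"
}

# Usage (not executed here):
#   docker tag jetstream-pytorch-server "$(image_uri "${PROJECT_ID}")"
#   docker push "$(image_uri "${PROJECT_ID}")"
```

Passing a second argument selects a tag other than `latest`; the tag name is the caller's choice and nothing in the commit prescribes one.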
20 changes: 20 additions & 0 deletions docker/jetstream-pytorch-server/jetstream_pytorch_server_entrypoint.sh
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

export HUGGINGFACE_TOKEN_DIR="/huggingface"

cd /jetstream-pytorch
huggingface-cli login --token "$(cat "${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN")"
jpt serve "$@"
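
The entrypoint reads the token from `${HUGGINGFACE_TOKEN_DIR}/HUGGINGFACE_TOKEN` without checking that the file exists, so a missing mount only surfaces as a `cat` error followed by a failed login. A minimal sketch of a fail-fast check, assuming a hypothetical helper named `read_hf_token` that is not part of this commit:

```shell
# Hypothetical helper (not in the repository): print the token stored in
# <dir>/HUGGINGFACE_TOKEN, or fail with a clear message on stderr.
read_hf_token() {
  hf_token_file="$1/HUGGINGFACE_TOKEN"
  if [ ! -r "${hf_token_file}" ]; then
    echo "token file not found or unreadable: ${hf_token_file}" >&2
    return 1
  fi
  # Remove all whitespace, including the trailing newline a mounted file
  # often carries (this assumes the token itself contains no whitespace).
  hf_token="$(tr -d '[:space:]' < "${hf_token_file}")"
  if [ -z "${hf_token}" ]; then
    echo "token file is empty: ${hf_token_file}" >&2
    return 1
  fi
  printf '%s\n' "${hf_token}"
}

# Usage inside an entrypoint (not executed here):
#   token="$(read_hf_token "${HUGGINGFACE_TOKEN_DIR}")" || exit 1
#   huggingface-cli login --token "${token}"
```

With a check like this, a container started without a directory mounted at `/huggingface` exits immediately with a message naming the missing file.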
4 changes: 3 additions & 1 deletion jetstream_pt/cli.py
@@ -109,7 +109,9 @@ def serve():
   if 1 <= FLAGS.prometheus_port <= 65535:
     metrics_server_config = MetricsServerConfig(port=FLAGS.prometheus_port)
   else:
-    raise ValueError(f"Invalid port number: {FLAGS.prometheus_port}. Port must be between 1 and 65535.")
+    raise ValueError(
+        f"Invalid port number: {FLAGS.prometheus_port}. Port must be between 1 and 65535."
+    )

# We separate credential from run so that we can unit test it with local credentials.
# We would like to add grpc credentials for OSS.