Skip to content

Commit

Permalink
Bump Jetstream, maxtext, jetstream-pytorch versions in Jetstream inference server guide (#695)
Browse files Browse the repository at this point in the history

* initial commit, working branch

* bump to official maxtext version

* revert server changes

* remove 'slabe' images

* revert checkpoint conversion changes

* Remove jetstream install
  • Loading branch information
Bslabe123 authored Jun 4, 2024
1 parent db8a568 commit 3096a0e
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ class GrpcBenchmarkUser(GrpcUser):
def grpc_infer(self):
prompt = get_random_prompt(self)
request = jetstream_pb2.DecodeRequest(
additional_text=prompt,
text_content=jetstream_pb2.DecodeRequest.TextContent(text=prompt),
priority=0,
max_tokens=model_params["max_output_len"],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ convert_maxtext_checkpoint() {
MAXTEXT_VERSION=$5

if [ -z $MAXTEXT_VERSION ]; then
MAXTEXT_VERSION=jetstream-v0.2.0
MAXTEXT_VERSION=jetstream-v0.2.2
fi

git clone https://github.com/google/maxtext.git
Expand All @@ -77,10 +77,10 @@ convert_pytorch_checkpoint() {
OUTPUT_CKPT_DIR=$3
QUANTIZE=$4
PYTORCH_VERSION=$5
JETSTREAM_VERSION=v0.2.0
JETSTREAM_VERSION=v0.2.2

if [ -z $PYTORCH_VERSION ]; then
PYTORCH_VERSION=jetstream-v0.2.0
PYTORCH_VERSION=jetstream-v0.2.2
fi

CKPT_PATH="$(echo ${INPUT_CKPT_DIR} | awk -F'gs://' '{print $2}')"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV JETSTREAM_VERSION=v0.2.0
ENV JETSTREAM_VERSION=v0.2.2

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ async def generate(request: GenerateRequest):
try:
request = jetstream_pb2.DecodeRequest(
session_cache=request.session_cache,
additional_text=request.prompt,
text_content=jetstream_pb2.DecodeRequest.TextContent(text=request.prompt),
priority=request.priority,
max_tokens=request.max_tokens,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV MAXTEXT_VERSION=jetstream-v0.2.0
ENV JETSTREAM_VERSION=v0.2.0
ENV MAXTEXT_VERSION=jetstream-v0.2.2

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
Expand All @@ -16,17 +15,12 @@ RUN apt -y update && apt install -y --no-install-recommends \
RUN update-alternatives --install \
/usr/bin/python3 python3 /usr/bin/python3.10 1

RUN git clone https://github.com/google/maxtext.git && \
git clone https://github.com/google/JetStream.git
RUN git clone https://github.com/google/maxtext.git

RUN cd maxtext/ && \
git checkout ${MAXTEXT_VERSION} && \
bash setup.sh

RUN cd /JetStream && \
git checkout ${JETSTREAM_VERSION} && \
pip install -e .

COPY maxengine_server_entrypoint.sh /usr/bin/

RUN chmod +x /usr/bin/maxengine_server_entrypoint.sh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ spec:
restartPolicy: Never
containers:
- name: inference-checkpoint
image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.0
image: us-docker.pkg.dev/cloud-tpu-images/inference/inference-checkpoint:v0.2.2
args:
- -b=BUCKET_NAME
- -m=google/gemma/maxtext/7b-it/2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ spec:
containers:
- name: maxengine-server
image: us-docker.pkg.dev/cloud-tpu-images/inference/maxengine-server:v0.2.0
imagePullPolicy: Always
securityContext:
privileged: true
args:
Expand All @@ -33,6 +34,7 @@ spec:
- scan_layers=false
- weight_dtype=bfloat16
- load_parameters_path=gs://BUCKET_NAME/final/unscanned/gemma_7b-it/0/checkpoints/0/items
- prometheus_port=9100
ports:
- containerPort: 9000
resources:
Expand All @@ -41,7 +43,8 @@ spec:
limits:
google.com/tpu: 8
- name: jetstream-http
image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.0
image: us-docker.pkg.dev/cloud-tpu-images/inference/jetstream-http:v0.2.2
imagePullPolicy: Always
ports:
- containerPort: 8000
---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
FROM ubuntu:22.04

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTORCH_JETSTREAM_VERSION=jetstream-v0.2.0
ENV JETSTREAM_VERSION=v0.2.0
ENV PYTORCH_JETSTREAM_VERSION=jetstream-v0.2.2

RUN apt -y update && apt install -y --no-install-recommends \
ca-certificates \
Expand All @@ -21,11 +20,6 @@ cd /jetstream-pytorch && \
git checkout ${PYTORCH_JETSTREAM_VERSION} && \
bash install_everything.sh

RUN git clone https://github.com/google/JetStream.git && \
cd /JetStream && \
git checkout ${JETSTREAM_VERSION} && \
pip install -e .

ENV PYTHONPATH=$PYTHONPATH:$(pwd)/deps/xla/experimental/torch_xla2:$(pwd)/JetStream:$(pwd)

COPY jetstream_pytorch_server_entrypoint.sh /usr/bin/
Expand Down

0 comments on commit 3096a0e

Please sign in to comment.