diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml new file mode 100644 index 0000000000..89df7c5c23 --- /dev/null +++ b/.github/pytorch-probot.yml @@ -0,0 +1,13 @@ +tracking_issue: 2724 + +# List of workflows that will be re-run in case of failures +# https://github.com/pytorch/test-infra/blob/main/torchci/lib/bot/retryBot.ts +retryable_workflows: +- Run Regression Tests on Docker +- Run Regression Tests for CPU nightly binaries +- Push torchserve nightly +- Push Docker Nightly +- Docker CI +- CI CPU +- CI GPU +- Benchmark torchserve nightly diff --git a/.github/workflows/kserve_cpu_tests.yml b/.github/workflows/kserve_cpu_tests.yml new file mode 100644 index 0000000000..beb91945e2 --- /dev/null +++ b/.github/workflows/kserve_cpu_tests.yml @@ -0,0 +1,40 @@ +name: KServe CPU Nightly Tests + +on: + workflow_dispatch: + # runs every day at 5:15am + schedule: + - cron: '15 5 * * *' + +jobs: + kserve-cpu-tests: + runs-on: [self-hosted, regression-test-gpu] + steps: + - name: Clean up previous run + run: | + echo "Cleaning up previous run" + ls -la ./ + sudo rm -rf ./* || true + sudo rm -rf ./.??* || true + ls -la ./ + - name: Install minikube and kubectl + run: | + curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube-linux-amd64 + sudo install minikube-linux-amd64 /usr/local/bin/minikube + curl -LO "https://dl.k8s.io/release/$(curl -L -s https://dl.k8s.io/release/stable.txt)/bin/linux/amd64/kubectl" + sudo install -o root -g root -m 0755 kubectl /usr/local/bin/kubectl + echo "/usr/local/bin" >> $GITHUB_PATH + - name: Setup Python 3.8 + uses: actions/setup-python@v4 + with: + python-version: 3.8 + architecture: x64 + - name: Checkout TorchServe + uses: actions/checkout@v3 + - name: Checkout kserve repo + uses: actions/checkout@v4 + with: + repository: kserve/kserve + path: kserve + - name: Validate torchserve-kfs + run: ./kubernetes/kserve/tests/scripts/test_mnist.sh diff --git a/.github/workflows/regression_tests_cpu_binaries.yml b/.github/workflows/regression_tests_cpu_binaries.yml index d5ad0878a2..858cdeaff3 100644 --- a/.github/workflows/regression_tests_cpu_binaries.yml +++ b/.github/workflows/regression_tests_cpu_binaries.yml @@ -39,6 +39,6 @@ jobs: - name: Install dependencies run: | python ts_scripts/install_dependencies.py --environment=dev - - name: Torchserve Regression Tests + - name: Validate Torchserve CPU Regression run: | python test/regression_tests.py --binaries --${{ matrix.binaries }} --nightly diff --git a/.github/workflows/regression_tests_docker.yml b/.github/workflows/regression_tests_docker.yml index b861fadecd..97b1fd7320 100644 --- a/.github/workflows/regression_tests_docker.yml +++ b/.github/workflows/regression_tests_docker.yml @@ -26,7 +26,7 @@ jobs: sudo rm -rf ./* || true sudo rm -rf ./.??* || true ls -la ./ - docker system prune -f + docker system prune --all --volumes -f - name: Checkout TorchServe uses: actions/checkout@v3 - name: Branch name @@ -42,11 +42,11 @@ jobs: run: | cd docker ./build_image.sh -g -cv cu121 -bt ci -n -b $GITHUB_REF_NAME -t pytorch/torchserve:ci - - name: Torchserve GPU Regression Tests + - name: Validate Torchserve GPU Regression if: false == contains(matrix.hardware, 'ubuntu') run: | docker run --gpus all -v $GITHUB_WORKSPACE:/home/serve pytorch/torchserve:ci - - name: Torchserve CPU Regression Tests + - name: Validate Torchserve CPU Regression if: contains(matrix.hardware, 'ubuntu') run: | docker run -v $GITHUB_WORKSPACE:/home/serve pytorch/torchserve:ci diff --git a/README.md b/README.md index 
76cd0100ee..c72b1a4320 100644 --- a/README.md +++ b/README.md @@ -55,19 +55,29 @@ docker pull pytorch/torchserve-nightly Refer to [torchserve docker](docker/README.md) for details. ## ⚡ Why TorchServe +* Write once, run anywhere: on-prem or on-cloud, with inference on CPUs, GPUs, AWS Inf1/Inf2/Trn1, Google Cloud TPUs and [Nvidia MPS](docs/nvidia_mps.md) * [Model Management API](docs/management_api.md): multi model management with optimized worker to model allocation * [Inference API](docs/inference_api.md): REST and gRPC support for batched inference * [TorchServe Workflows](examples/Workflows/README.md): deploy complex DAGs with multiple interdependent models * Default way to serve PyTorch models in - * [Kubeflow](https://v0-5.kubeflow.org/docs/components/pytorchserving/) - * [MLflow](https://github.com/mlflow/mlflow-torchserve) * [Sagemaker](https://aws.amazon.com/blogs/machine-learning/serving-pytorch-models-in-production-with-the-amazon-sagemaker-native-torchserve-integration/) - * [Kserve](https://kserve.github.io/website/0.8/modelserving/v1beta1/torchserve/): Supports both v1 and v2 API * [Vertex AI](https://cloud.google.com/blog/topics/developers-practitioners/pytorch-google-cloud-how-deploy-pytorch-models-vertex-ai) -* Export your model for optimized inference. Torchscript out of the box, [ORT and ONNX](https://github.com/pytorch/serve/blob/master/docs/performance_guide.md), [IPEX](https://github.com/pytorch/serve/tree/master/examples/intel_extension_for_pytorch), [TensorRT](https://github.com/pytorch/serve/blob/master/docs/performance_guide.md), [FasterTransformer](https://github.com/pytorch/serve/tree/master/examples/FasterTransformer_HuggingFace_Bert) + * [Kubernetes](kubernetes) with support for [autoscaling](kubernetes#session-affinity-with-multiple-torchserve-pods), session-affinity and monitoring using Grafana; works on-prem and on AWS EKS, Google GKE, Azure AKS + * [Kserve](https://kserve.github.io/website/0.8/modelserving/v1beta1/torchserve/): Supports both v1 and v2 API, [autoscaling and canary deployments](kubernetes/kserve/README.md#autoscaling) for A/B testing + * [Kubeflow](https://v0-5.kubeflow.org/docs/components/pytorchserving/) + * [MLflow](https://github.com/mlflow/mlflow-torchserve) +* Export your model for optimized inference. 
Torchscript out of the box, [PyTorch Compiler](examples/pt2/README.md) preview, [ORT and ONNX](https://github.com/pytorch/serve/blob/master/docs/performance_guide.md), [IPEX](https://github.com/pytorch/serve/tree/master/examples/intel_extension_for_pytorch), [TensorRT](https://github.com/pytorch/serve/blob/master/docs/performance_guide.md), [FasterTransformer](https://github.com/pytorch/serve/tree/master/examples/FasterTransformer_HuggingFace_Bert), FlashAttention (Better Transformer) * [Performance Guide](docs/performance_guide.md): builtin support to optimize, benchmark and profile PyTorch and TorchServe performance * [Expressive handlers](CONTRIBUTING.md): An expressive handler architecture that makes it trivial to support inferencing for your usecase with [many supported out of the box](https://github.com/pytorch/serve/tree/master/ts/torch_handler) -* [Metrics API](docs/metrics.md): out of box support for system level metrics with [Prometheus exports](https://github.com/pytorch/serve/tree/master/examples/custom_metrics), custom metrics and PyTorch profiler support +* [Metrics API](docs/metrics.md): out of box support for system level metrics with [Prometheus exports](https://github.com/pytorch/serve/tree/master/examples/custom_metrics) and custom metrics +* [Large Model Inference Guide](docs/large_model_inference.md): With support for GenAI and LLMs, including + * Fast Kernels with FlashAttention v2, continuous batching and streaming response + * PyTorch [Tensor Parallel](examples/large_models/tp_llama) preview, [Pipeline Parallel](examples/large_models/Huggingface_pippy) + * Microsoft [DeepSpeed](examples/large_models/deepspeed), [DeepSpeed-Mii](examples/large_models/deepspeed_mii) + * Hugging Face [Accelerate](examples/large_models/Huggingface_accelerate), [Diffusers](examples/diffusers) + * Running large models on AWS [Sagemaker](https://docs.aws.amazon.com/sagemaker/latest/dg/large-model-inference-tutorials-torchserve.html) and [Inferentia2](https://pytorch.org/blog/high-performance-llama/) + * Running [Llama 2 Chatbot locally on Mac](examples/LLM/llama2) +* Monitoring using Grafana and [Datadog](https://www.datadoghq.com/blog/ai-integrations/#model-serving-and-deployment-vertex-ai-amazon-sagemaker-torchserve) ## 🤔 How does TorchServe work @@ -80,6 +90,7 @@ Refer to [torchserve docker](docker/README.md) for details. * [Serving Llama 2 with TorchServe](examples/LLM/llama2/README.md) * [Chatbot with Llama 2 on Mac 🦙💬](examples/LLM/llama2/chat_app) * [🤗 HuggingFace Transformers](examples/Huggingface_Transformers) with a [Better Transformer Integration/ Flash Attention & Xformer Memory Efficient ](examples/Huggingface_Transformers#Speed-up-inference-with-Better-Transformer) +* [Stable Diffusion](examples/diffusers) * [Model parallel inference](examples/Huggingface_Transformers#model-parallelism) * [MultiModal models with MMF](https://github.com/pytorch/serve/tree/master/examples/MMF-activity-recognition) combining text, audio and video * [Dual Neural Machine Translation](examples/Workflows/nmt_transformers_pipeline) for a complex workflow DAG @@ -100,6 +111,12 @@ We welcome all contributions! To learn more about how to contribute, see the contributor guide [here](https://github.com/pytorch/serve/blob/master/CONTRIBUTING.md). 
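The expressive-handler bullet above is the usual entry point for custom serving logic. As a rough, hypothetical sketch only (the class name and payload handling are illustrative, not an example shipped with this patch), a custom handler built on `ts.torch_handler.base_handler.BaseHandler` (the same base class used by the new pytest handlers later in this diff) typically overrides just the pre- and post-processing hooks:

```python
# Hypothetical custom handler sketch; class name and payload format are illustrative.
import torch
from ts.torch_handler.base_handler import BaseHandler


class JsonScalarHandler(BaseHandler):
    """Accepts a JSON number per request and returns the model output as a plain scalar."""

    def preprocess(self, data):
        # TorchServe hands the handler a list of request dicts; the payload is under "data" or "body".
        body = data[0].get("data") or data[0].get("body")
        if isinstance(body, (bytes, bytearray)):
            body = body.decode("utf-8")
        return torch.as_tensor([float(body)], device=self.device)

    def postprocess(self, inference_output):
        # Must return one entry per request in the batch.
        return [inference_output.item()]
```

Packaged with `torch-model-archiver` and registered through the Management API, requests to `/predictions/<model>` then flow through `handle()`, `preprocess`, `inference` and `postprocess` with no further glue code.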
## 📰 News +* [High performance Llama 2 deployments with AWS Inferentia2 using TorchServe](https://pytorch.org/blog/high-performance-llama/) +* [Naver Case Study: Transition From High-Cost GPUs to Intel CPUs and oneAPI powered Software with performance](https://pytorch.org/blog/ml-model-server-resource-saving/) +* [Run multiple generative AI models on GPU using Amazon SageMaker multi-model endpoints with TorchServe and save up to 75% in inference costs](https://aws.amazon.com/blogs/machine-learning/run-multiple-generative-ai-models-on-gpu-using-amazon-sagemaker-multi-model-endpoints-with-torchserve-and-save-up-to-75-in-inference-costs/) +* [Deploying your Generative AI model in only four steps with Vertex AI and PyTorch](https://cloud.google.com/blog/products/ai-machine-learning/get-your-genai-model-going-in-four-easy-steps) +* [PyTorch Model Serving on Google Cloud TPU v5](https://cloud.google.com/tpu/docs/v5e-inference#pytorch-model-inference-and-serving) +* [Monitoring using Datadog](https://www.datadoghq.com/blog/ai-integrations/#model-serving-and-deployment-vertex-ai-amazon-sagemaker-torchserve) * [Torchserve Performance Tuning, Animated Drawings Case-Study](https://pytorch.org/blog/torchserve-performance-tuning/) * [Walmart Search: Serving Models at a Scale on TorchServe](https://medium.com/walmartglobaltech/search-model-serving-using-pytorch-and-torchserve-6caf9d1c5f4d) * [🎥 Scaling inference on CPU with TorchServe](https://www.youtube.com/watch?v=066_Jd6cwZg) diff --git a/SECURITY.md b/SECURITY.md index 38d22373c6..1f424bcfa3 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -3,8 +3,8 @@ ## Supported Versions | Version | Supported | -| ------- | ------------------ | -| 0.8.2 | :white_check_mark: | +|---------| ------------------ | +| 0.9.0 | :white_check_mark: | ## How we do security diff --git a/benchmarks/config.properties b/benchmarks/config.properties index 5d819a29c6..a1b672d2c2 100644 --- a/benchmarks/config.properties +++ b/benchmarks/config.properties @@ -1,5 +1,5 @@ -inference_address=http://0.0.0.0:8080 -management_address=http://0.0.0.0:8081 +inference_address=http://127.0.0.1:8080 +management_address=http://127.0.0.1:8081 number_of_netty_threads=32 job_queue_size=1000 diff --git a/benchmarks/config_template.properties b/benchmarks/config_template.properties index 1b1e9772dd..c2be608e54 100644 --- a/benchmarks/config_template.properties +++ b/benchmarks/config_template.properties @@ -1,2 +1,2 @@ -inference_address=http://0.0.0.0:8080 -management_address=http://0.0.0.0:8081 +inference_address=http://127.0.0.1:8080 +management_address=http://127.0.0.1:8081 diff --git a/docker/build_upload_release.py b/docker/build_upload_release.py index 8def7bb217..44c3812297 100644 --- a/docker/build_upload_release.py +++ b/docker/build_upload_release.py @@ -56,7 +56,7 @@ f"{organization}/torchserve:{check_ts_version()}-cpu", f"{organization}/torchserve:{check_ts_version()}-gpu", ]: - os.system(f"docker push {image}") + try_and_handle(f"docker push {image}", dry_run) # Cleanup built images if args.cleanup: diff --git a/docs/batch_inference_with_ts.md b/docs/batch_inference_with_ts.md index b4f339d5a1..3ff04be63b 100644 --- a/docs/batch_inference_with_ts.md +++ b/docs/batch_inference_with_ts.md @@ -166,11 +166,11 @@ curl http://localhost:8081/models/resnet-152-batch_v2 ```text $ curl http://localhost:8080/predictions/resnet-152-batch_v2 -T kitten.jpg { - "tiger_cat": 0.5848360657691956, - "tabby": 0.3782736361026764, - "Egyptian_cat": 0.03441936895251274, - "lynx": 0.0005633446853607893, - 
"quilt": 0.0002698268508538604 + "tiger_cat": 0.5798614621162415, + "tabby": 0.38344162702560425, + "Egyptian_cat": 0.0342114195227623, + "lynx": 0.0005819813231937587, + "quilt": 0.000273319921689108 } ``` ### Batch inference of Resnet-152 configured through config.properties @@ -249,11 +249,11 @@ curl http://localhost:8081/models/resnet-152-batch_v2 ```text $ curl http://localhost:8080/predictions/resnet-152-batch_v2 -T kitten.jpg { - "tiger_cat": 0.5848360657691956, - "tabby": 0.3782736361026764, - "Egyptian_cat": 0.03441936895251274, - "lynx": 0.0005633446853607893, - "quilt": 0.0002698268508538604 + "tiger_cat": 0.5798614621162415, + "tabby": 0.38344162702560425, + "Egyptian_cat": 0.0342114195227623, + "lynx": 0.0005819813231937587, + "quilt": 0.000273319921689108 } ``` ## Demo to configure TorchServe ResNet-152 model with batch-supported model using Docker @@ -339,10 +339,10 @@ curl http://localhost:8081/models/resnet-152-batch_v2 ```text $ curl http://localhost:8080/predictions/resnet-152-batch_v2 -T kitten.jpg { - "tiger_cat": 0.5848360657691956, - "tabby": 0.3782736361026764, - "Egyptian_cat": 0.03441936895251274, - "lynx": 0.0005633446853607893, - "quilt": 0.0002698268508538604 + "tiger_cat": 0.5798614621162415, + "tabby": 0.38344162702560425, + "Egyptian_cat": 0.0342114195227623, + "lynx": 0.0005819813231937587, + "quilt": 0.000273319921689108 } ``` diff --git a/docs/index.rst b/docs/index.rst index f16037417e..06a36018fc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -9,6 +9,12 @@ TorchServe is a performant, flexible and easy to use tool for serving PyTorch mo What's going on in TorchServe? +* `High performance Llama 2 deployments with AWS Inferentia2 using TorchServe `__ +* `Naver Case Study: Transition From High-Cost GPUs to Intel CPUs and oneAPI powered Software with performance `__ +* `Run multiple generative AI models on GPU using Amazon SageMaker multi-model endpoints with TorchServe and save up to 75% in inference costs `__ +* `Deploying your Generative AI model in only four steps with Vertex AI and PyTorch `__ +* `PyTorch Model Serving on Google Cloud TPUv5 `__ +* `Monitoring using Datadog `__ * `Torchserve Performance Tuning, Animated Drawings Case-Study `__ * `Walmart Search: Serving Models at a Scale on TorchServe `__ * `Scaling inference on CPU with TorchServe `__ diff --git a/examples/large_models/tp_llama/REAME.md b/examples/large_models/tp_llama/README.md similarity index 100% rename from examples/large_models/tp_llama/REAME.md rename to examples/large_models/tp_llama/README.md diff --git a/examples/pt2/README.md b/examples/pt2/README.md index dbffc749ec..0758b089af 100644 --- a/examples/pt2/README.md +++ b/examples/pt2/README.md @@ -46,6 +46,17 @@ opt_mod = torch.compile(mod) torchserve takes care of 4 and 5 for you while the remaining steps are your responsibility. You can do the exact same thing on the vast majority of TIMM or HuggingFace models. +### Note + +`torch.compile()` is a JIT compiler and JIT compilers generally have a startup cost. If that's an issue for you make sure to populate these two environment variables to improve your warm starts. + +``` +import os + +os.environ["TORCHINDUCTOR_CACHE_DIR"] = "1" +os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "/path/to/directory" # replace with your desired path +``` + ## torch.export.export Export your model from a training script, keep in mind that an exported model cannot have graph breaks. 
diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java index 57348de638..1f89f4a48a 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/codec/ModelRequestEncoder.java @@ -76,7 +76,7 @@ private void encodeRequest(RequestInput req, ByteBuf out) { out.writeInt(buf.length); out.writeBytes(buf); - if (req.isCached()) { + if (req.isCachedInBackend()) { out.writeInt(-1); // End of List out.writeInt(-1); // End of List return; @@ -92,7 +92,6 @@ private void encodeRequest(RequestInput req, ByteBuf out) { encodeParameter(input, out); } out.writeInt(-1); // End of List - req.setCached(true); } private void encodeParameter(InputParameter parameter, ByteBuf out) { diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/messages/ModelInferenceRequest.java b/frontend/server/src/main/java/org/pytorch/serve/util/messages/ModelInferenceRequest.java index 9a4c73af76..e83b6d95eb 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/messages/ModelInferenceRequest.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/messages/ModelInferenceRequest.java @@ -23,4 +23,10 @@ public void setRequestBatch(List requestBatch) { public void addRequest(RequestInput req) { batch.add(req); } + + public void setCachedInBackend(boolean cached) { + for (RequestInput input : batch) { + input.setCachedInBackend(cached); + } + } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java b/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java index 5717908f0f..0db8e84064 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/messages/RequestInput.java @@ -73,11 +73,11 @@ public void setClientExpireTS(long clientTimeoutInMills) { } } - public boolean isCached() { + public boolean isCachedInBackend() { return cached; } - public void setCached(boolean cached) { + public void setCachedInBackend(boolean cached) { this.cached = cached; } } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index 90f294e5cf..178bbb91a2 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -37,6 +37,7 @@ import org.pytorch.serve.util.codec.ModelResponseDecoder; import org.pytorch.serve.util.messages.BaseModelRequest; import org.pytorch.serve.util.messages.InputParameter; +import org.pytorch.serve.util.messages.ModelInferenceRequest; import org.pytorch.serve.util.messages.ModelWorkerResponse; import org.pytorch.serve.util.messages.RequestInput; import org.pytorch.serve.util.messages.WorkerCommands; @@ -208,6 +209,9 @@ public void run() { for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) { backendChannel.get(i).writeAndFlush(req).sync(); } + if (req instanceof ModelInferenceRequest) { + ((ModelInferenceRequest) req).setCachedInBackend(true); + } ModelWorkerResponse reply = null; @@ -313,6 +317,7 @@ public void run() { i++) { backendChannel.get(i).disconnect(); } + backendChannel.clear(); currentThread.set(null); Integer exitValue = lifeCycle.getExitValue(); @@ -462,6 +467,7 @@ public void shutdown() { 
backendChannel.get(i).close(); } } + backendChannel.clear(); lifeCycle.terminateIOStreams(); Thread thread = currentThread.getAndSet(null); if (thread != null) { diff --git a/kubernetes/EKS/README.md b/kubernetes/EKS/README.md index 14a7b656fe..c932f5e914 100644 --- a/kubernetes/EKS/README.md +++ b/kubernetes/EKS/README.md @@ -506,8 +506,8 @@ ```yaml - inference_address=http://127.0.0.1:8080 - management_address=http://127.0.0.1:8081 + inference_address=http://0.0.0.0:8080 + management_address=http://0.0.0.0:8081 NUM_WORKERS=1 number_of_gpu=1 number_of_netty_threads=32 diff --git a/kubernetes/examples/FasterTransformer_HuggingFace_Bert.md b/kubernetes/examples/FasterTransformer_HuggingFace_Bert.md index 53f9c49827..7d1b696e0b 100644 --- a/kubernetes/examples/FasterTransformer_HuggingFace_Bert.md +++ b/kubernetes/examples/FasterTransformer_HuggingFace_Bert.md @@ -33,9 +33,9 @@ docker cp :/workspace/serve/examples/FasterTransformer_HuggingFace ## Create config.properties ```bash -inference_address=http://127.0.0.1:8080 -management_address=http://127.0.0.1:8081 -metrics_address=http://127.0.0.1:8082 +inference_address=http://0.0.0.0:8080 +management_address=http://0.0.0.0:8081 +metrics_address=http://0.0.0.0:8082 NUM_WORKERS=1 number_of_gpu=1 install_py_dep_per_model=true diff --git a/kubernetes/kserve/README.md b/kubernetes/kserve/README.md index f439bd7ce7..cf54a6ce73 100644 --- a/kubernetes/kserve/README.md +++ b/kubernetes/kserve/README.md @@ -109,9 +109,9 @@ torch-model-archiver --model-name mnist_kf --version 1.0 --model-file examples/i - Step - 2 : Create a config.properties file and place the contents like below: ```bash -inference_address=http://127.0.0.1:8085 -management_address=http://127.0.0.1:8081 -metrics_address=http://127.0.0.1:8082 +inference_address=http://0.0.0.0:8085 +management_address=http://0.0.0.0:8081 +metrics_address=http://0.0.0.0:8082 grpc_inference_port=7070 grpc_management_port=7071 enable_envvars_config=true diff --git a/kubernetes/kserve/build_upload_release.py b/kubernetes/kserve/build_upload_release.py index d10ae8533f..55183c7a03 100644 --- a/kubernetes/kserve/build_upload_release.py +++ b/kubernetes/kserve/build_upload_release.py @@ -43,7 +43,7 @@ f"{organization}/torchserve-kfs:{check_ts_version()}", f"{organization}/torchserve-kfs:{check_ts_version()}-gpu", ]: - os.system(f"docker push {image}") + try_and_handle(f"docker push {image}", dry_run) # Cleanup built images if args.cleanup: diff --git a/kubernetes/kserve/config.properties b/kubernetes/kserve/config.properties index 91fbb7483b..422e53d138 100644 --- a/kubernetes/kserve/config.properties +++ b/kubernetes/kserve/config.properties @@ -1,7 +1,7 @@ #Sample config.properties. In production config.properties at /mnt/models/config/config.properties will be used -inference_address=http://127.0.0.1:8085 -management_address=http://127.0.0.1:8085 -metrics_address=http://127.0.0.1:8082 +inference_address=http://0.0.0.0:8085 +management_address=http://0.0.0.0:8085 +metrics_address=http://0.0.0.0:8082 grpc_inference_port=7070 grpc_management_port=7071 enable_envvars_config=true diff --git a/kubernetes/kserve/examples/mnist/MNIST.md b/kubernetes/kserve/examples/mnist/MNIST.md new file mode 100644 index 0000000000..24efd2a2bc --- /dev/null +++ b/kubernetes/kserve/examples/mnist/MNIST.md @@ -0,0 +1,164 @@ +# Digit recognition model with MNIST dataset using a Kubernetes cluster + +In this example, we show how to use a pre-trained custom MNIST model to perform real time Digit recognition with TorchServe. 
+We will serve the model using KServe, deployed on [minikube](https://minikube.sigs.k8s.io/docs/start/). + +The inference service returns the digit inferred by the model from the input image. + + +## Install KServe + +Start minikube cluster + +``` +minikube start +``` + +For this example, we need to git clone [kserve](https://github.com/kserve/kserve). +Run the commands in the following steps from the root of the cloned repository. For example, if you cloned the repository into /home/my_path/kserve, run the steps from /home/my_path/kserve + +Run the following for a quick install of KServe +``` +./hack/quick_install.sh +``` + +Make sure KServe is installed on the minikube cluster using + +``` +kubectl get pods -n kserve +``` + +This should result in +``` +NAME READY STATUS RESTARTS AGE +kserve-controller-manager-57574b4878-rnsjn 2/2 Running 0 17s +``` + +TorchServe supports the KServe V1 and V2 protocols. We show how to deploy MNIST with both. + +## KServe V1 protocol + +Deploy `InferenceService` with the KServe V1 protocol + +``` +kubectl apply -f docs/samples/v1beta1/torchserve/v1/torchserve.yaml +``` + +results in + +``` +inferenceservice.serving.kserve.io/torchserve created +``` + +We need to wait until the pod is up + +``` +kubectl get pods +NAME READY STATUS RESTARTS AGE +torchserve-predictor-00001-deployment-8d66f9c-dkdhr 2/2 Running 0 8m19s +``` + +We need to set the following + +``` +MODEL_NAME=mnist +SERVICE_HOSTNAME=$(kubectl get inferenceservice torchserve -o jsonpath='{.status.url}' | cut -d "/" -f 3) +``` + +``` +export INGRESS_HOST=localhost +export INGRESS_PORT=8080 +``` + +``` +INGRESS_GATEWAY_SERVICE=$(kubectl get svc --namespace istio-system --selector="app=istio-ingressgateway" --output jsonpath='{.items[0].metadata.name}') +kubectl port-forward --namespace istio-system svc/${INGRESS_GATEWAY_SERVICE} 8080:80 & +``` + +Make an inference request + +``` +curl -H "Content-Type: application/json" -H "Host: ${SERVICE_HOSTNAME}" http://${INGRESS_HOST}:${INGRESS_PORT}/v1/models/${MODEL_NAME}:predict -d @./docs/samples/v1beta1/torchserve/v1/mnist.json +``` + +Expected output is + +``` +{"predictions":[2]} +``` + +## KServe V2 protocol + +Deploy `InferenceService` with the KServe V2 protocol + +``` +kubectl apply -f docs/samples/v1beta1/torchserve/v2/mnist.yaml +``` + +results in + +``` +inferenceservice.serving.kserve.io/torchserve-mnist-v2 created +``` + +We need to check that the pod is running with + +``` +kubectl get pods +NAME READY STATUS RESTARTS AGE +torchserve-mnist-v2-predictor-00001-deployment-6c8c684dcb-4mfmr 2/2 Running 0 2m37s +``` + +Inspect the logs of the pod to check the version of TorchServe + +``` +kubectl logs torchserve-mnist-v2-predictor-00001-deployment-6c8c684dcb-4mfmr +Defaulted container "kserve-container" out of: kserve-container, queue-proxy, storage-initializer (init) +WARNING: sun.reflect.Reflection.getCallerClass is not supported. This will impact performance. +2023-10-12T20:50:39,466 [WARN ] main org.pytorch.serve.util.ConfigManager - Your torchserve instance can access any URL to load models. When deploying to production, make sure to limit the set of allowed_urls in config.properties +2023-10-12T20:50:39,468 [INFO ] main org.pytorch.serve.servingsdk.impl.PluginsManager - Initializing plugins manager... 
+2023-10-12T20:50:39,659 [INFO ] main org.pytorch.serve.metrics.configuration.MetricConfiguration - Successfully loaded metrics configuration from /home/venv/lib/python3.9/site-packages/ts/configs/metrics.yaml +2023-10-12T20:50:39,779 [INFO ] main org.pytorch.serve.ModelServer - +Torchserve version: 0.8.2 +TS Home: /home/venv/lib/python3.9/site-packages +Current directory: /home/model-server +Temp directory: /home/model-server/tmp +Metrics config path: /home/venv/lib/python3.9/site-packages/ts/configs/metrics.yaml + +``` + +We need to set the following + +``` +MODEL_NAME=mnist +SERVICE_HOSTNAME=$(kubectl get inferenceservice torchserve-mnist-v2 -o jsonpath='{.status.url}' | cut -d "/" -f 3) +``` + +``` +export INGRESS_HOST=localhost +export INGRESS_PORT=8080 +``` + +``` +INGRESS_GATEWAY_SERVICE=$(kubectl get svc --namespace istio-system --selector="app=istio-ingressgateway" --output jsonpath='{.items[0].metadata.name}') +kubectl port-forward --namespace istio-system svc/${INGRESS_GATEWAY_SERVICE} 8080:80 & +``` + +Make an inference request with tensor input + +``` +curl -v -H "Content-Type: application/json" -H "Host: ${SERVICE_HOSTNAME}" http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}/infer -d @./docs/samples/v1beta1/torchserve/v2/tensor_conv/mnist_v2.json +``` + +Expected output is + +``` +{"model_name":"mnist","model_version":null,"id":"d3b15cad-50a2-4eaf-80ce-8b0a428bd298","parameters":null,"outputs":[{"name":"input-0","shape":[1],"datatype":"INT64","parameters":null,"data":[1]}]} +``` + +## Stop and Delete the cluster + +``` +minikube stop +minikube delete +``` diff --git a/kubernetes/kserve/image_transformer/transformer.Dockerfile b/kubernetes/kserve/image_transformer/transformer.Dockerfile index 5399b88413..88e4aed9c5 100644 --- a/kubernetes/kserve/image_transformer/transformer.Dockerfile +++ b/kubernetes/kserve/image_transformer/transformer.Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.7-slim +FROM python:3.8.18-slim ARG BRANCH_NAME_KF=master RUN apt-get update \ diff --git a/kubernetes/kserve/tests/configs/mnist_v1_cpu.yaml b/kubernetes/kserve/tests/configs/mnist_v1_cpu.yaml new file mode 100644 index 0000000000..8c6b044244 --- /dev/null +++ b/kubernetes/kserve/tests/configs/mnist_v1_cpu.yaml @@ -0,0 +1,9 @@ +apiVersion: serving.kserve.io/v1beta1 +kind: InferenceService +metadata: + name: "torchserve" +spec: + predictor: + pytorch: + storageUri: gs://kfserving-examples/models/torchserve/image_classifier/v1 + image: pytorch/torchserve-kfs-nightly:latest-cpu diff --git a/kubernetes/kserve/tests/configs/mnist_v2_cpu.yaml b/kubernetes/kserve/tests/configs/mnist_v2_cpu.yaml new file mode 100644 index 0000000000..f60efc14e0 --- /dev/null +++ b/kubernetes/kserve/tests/configs/mnist_v2_cpu.yaml @@ -0,0 +1,10 @@ +apiVersion: serving.kserve.io/v1beta1 +kind: InferenceService +metadata: + name: "torchserve-mnist-v2" +spec: + predictor: + pytorch: + protocolVersion: v2 + storageUri: gs://kfserving-examples/models/torchserve/image_classifier/v2 + image: pytorch/torchserve-kfs-nightly:latest-cpu diff --git a/kubernetes/kserve/tests/scripts/test_mnist.sh b/kubernetes/kserve/tests/scripts/test_mnist.sh new file mode 100755 index 0000000000..e9b012a757 --- /dev/null +++ b/kubernetes/kserve/tests/scripts/test_mnist.sh @@ -0,0 +1,141 @@ +#!/usr/bin/env bash + +set -o errexit -o nounset -o pipefail + +function start_minikube_cluster() { + echo "Removing any previous Kubernetes cluster" + minikube delete + echo "Starting Kubernetes cluster" + minikube start +} + +function install_kserve() { 
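+    # Installs KServe into the minikube cluster by running hack/quick_install.sh from the
+    # kserve checkout under $GITHUB_WORKSPACE/kserve, then waits up to 300s (polling every 5s)
+    # for the kserve controller pod to reach the Running phase.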
+ echo "Install Kserve" + cd $GITHUB_WORKSPACE/kserve + ./hack/quick_install.sh + echo "Waiting for Kserve pod to come up ..." + wait_for_kserve_pod 300 5 +} + +function deploy_cluster() { + echo "Deploying the cluster" + cd $GITHUB_WORKSPACE + kubectl apply -f "$1" + echo "Waiting for pod to come up..." + wait_for_pod_running "$2" 120 + echo "Check status of the pod" + kubectl get pods + kubectl describe pod "$2" +} + +function make_cluster_accessible() { + SERVICE_NAME="$1" + URL="$2" + wait_for_inference_service 300 5 "$1" + SERVICE_HOSTNAME=$(kubectl get inferenceservice ${SERVICE_NAME} -o jsonpath='{.status.url}' | cut -d "/" -f 3) + wait_for_port_forwarding 5 + echo "Make inference request" + PREDICTION=$(curl -H "Content-Type: application/json" -H "Host: ${SERVICE_HOSTNAME}" ${URL} -d @"$3") + EXPECTED="$4" + if [ "${PREDICTION}" = "${EXPECTED}" ]; then + echo "✓ SUCCESS" + else + echo "✘ Test failed: Prediction: ${PREDICTION}, expected ${EXPECTED}." + delete_minikube_cluster + exit 1 + fi +} + +function delete_minikube_cluster() { + echo "Delete cluster" + minikube delete +} + +function wait_for_inference_service() { + echo "Wait for inference service to be ready" + max_wait_time="$1" + interval="$2" + SERVICE_NAME="$3" + start_time=$(date +%s) + while true; do + service_status=$(kubectl get inferenceservice ${SERVICE_NAME} -o jsonpath='{.status.conditions[?(@.type=="Ready")].status}') + if [[ "$service_status" == "True" ]]; then + break + fi + current_time=$(date +%s) + if (( current_time - start_time >= max_wait_time )); then + echo "Timeout waiting for inference service to come up." + delete_minikube_cluster + exit 1 + fi + sleep "$interval" + done +} +function wait_for_kserve_pod() { + max_wait_time="$1" + interval="$2" + start_time=$(date +%s) + while true; do + kserve_pod_status=$(kubectl get pods -n kserve --no-headers -o custom-columns=":status.phase") + if [[ "$kserve_pod_status" == "Running" ]]; then + break + fi + current_time=$(date +%s) + if (( current_time - start_time >= max_wait_time )); then + echo "Timeout waiting for Kserve pod to come up." + delete_minikube_cluster + exit 1 + fi + sleep "$interval" + done +} + +function wait_for_pod_running() { + pod_name="$1" + max_wait_time="$2" + interval=5 + start_time=$(date +%s) + while true; do + sleep "$interval" + pod_description=$(kubectl describe pod "$pod_name") + status_line=$(echo "$pod_description" | grep -E "Status:") + pod_status=$(echo "$status_line" | awk '{print $2}') + if [[ "$pod_status" == "Running" ]]; then + break + fi + current_time=$(date +%s) + if (( current_time - start_time >= max_wait_time )); then + echo "Timeout waiting for pod $pod_name to become Running." 
+ delete_minikube_cluster + exit 1 + fi + done +} + +function wait_for_port_forwarding() { + echo "Wait for ports to be in forwarding" + interval="$1" + start_time=$(date +%s) + INGRESS_GATEWAY_SERVICE=$(kubectl get svc --namespace istio-system --selector="app=istio-ingressgateway" --output jsonpath='{.items[0].metadata.name}') + kubectl port-forward --namespace istio-system svc/${INGRESS_GATEWAY_SERVICE} 8080:80 & + sleep "$interval" +} + +export INGRESS_HOST=localhost +export INGRESS_PORT=8080 +export MODEL_NAME=mnist + +start_minikube_cluster +install_kserve + +echo "MNIST KServe V2 test begin" +deploy_cluster "kubernetes/kserve/tests/configs/mnist_v2_cpu.yaml" "torchserve-mnist-v2-predictor" +URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v2/models/${MODEL_NAME}/infer" +make_cluster_accessible "torchserve-mnist-v2" ${URL} "./kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor.json" '{"model_name":"mnist","model_version":null,"id":"d3b15cad-50a2-4eaf-80ce-8b0a428bd298","parameters":null,"outputs":[{"name":"input-0","shape":[1],"datatype":"INT64","parameters":null,"data":[1]}]}' + +echo "MNIST KServe V1 test begin" +deploy_cluster "kubernetes/kserve/tests/configs/mnist_v1_cpu.yaml" "torchserve-predictor" +URL="http://${INGRESS_HOST}:${INGRESS_PORT}/v1/models/${MODEL_NAME}:predict" +make_cluster_accessible "torchserve" ${URL} "./kubernetes/kserve/kf_request_json/v1/mnist.json" '{"predictions":[2]}' + +delete_minikube_cluster diff --git a/requirements/developer.txt b/requirements/developer.txt index bf09ab8b69..a087314a8a 100644 --- a/requirements/developer.txt +++ b/requirements/developer.txt @@ -7,7 +7,7 @@ pytest-cov==4.1.0 grpcio==1.54.2 protobuf==4.23.1 grpcio-tools==1.54.2 -transformers==4.30.0 +transformers>=4.34.0 pyspelling==2.8.2 pygit2==1.13.1 pre-commit==3.3.2 diff --git a/test/pytest/test_auto_recover.py b/test/pytest/test_auto_recover.py new file mode 100644 index 0000000000..87bca76c4a --- /dev/null +++ b/test/pytest/test_auto_recover.py @@ -0,0 +1,180 @@ +import json +import platform +import shutil +from argparse import Namespace +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import requests +import test_utils + +CURR_FILE_PATH = Path(__file__).parent +REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent + +MODEL_PY = """ +import torch +import torch.nn as nn + +class Foo(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x +""" + +HANDLER_PY = """ +import time + +from typing import List, Dict, Any, Tuple +from ts.context import Context + + +class FailingModel(object): + def __init__(self) -> None: + pass + + def initialize(self, context: Context) -> None: + print(f"[xxx] Model initialization ... !!") + self.initialized = True + print(f"[xxx] Model initialization ... DONE !!") + + def handle(self, data: List[Dict[str, Any]], context: Context): + self.context = context + + output = list() + for idx, row in enumerate(data): + # run + print(f"[xxx] run ... !!") + time.sleep(5) + print(f"[xxx] run ... 
DONE !!") + output.append(f"sample output {idx}") + return output +""" + +CONFIG_PROPERTIES = """ +default_response_timeout=2 +""" + + +@pytest.fixture(scope="module") +def model_name(): + yield "tp_model" + + +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return Path(tmp_path_factory.mktemp(model_name)) + + +@pytest.fixture(scope="module") +def torchserve(model_store, work_dir): + test_utils.torchserve_cleanup() + + config_properties_file = work_dir / "config.properties" + config_properties_file.write_text(CONFIG_PROPERTIES) + + pipe = test_utils.start_torchserve( + model_store=model_store, + no_config_snapshots=True, + gen_mar=False, + snapshot_file=config_properties_file.as_posix(), + ) + + yield pipe + + test_utils.torchserve_cleanup() + + +@pytest.fixture(scope="module", name="mar_file_path") +def create_mar_file(work_dir, model_archiver, model_name): + mar_file_path = work_dir.joinpath(model_name + ".mar") + + model_py_file = work_dir / "model.py" + model_py_file.write_text(MODEL_PY) + + handler_py_file = work_dir / "handler.py" + handler_py_file.write_text(HANDLER_PY) + + args = Namespace( + model_name=model_name, + version="1.0", + serialized_file=None, + model_file=model_py_file.as_posix(), + handler=handler_py_file.as_posix(), + extra_files=None, + export_path=work_dir, + requirements_file=None, + runtime="python", + force=False, + archive_format="default", + config_file=None, + ) + + mock = MagicMock() + mock.parse_args = MagicMock(return_value=args) + with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + model_archiver.generate_model_archive() + + assert mar_file_path.exists() + + yield mar_file_path.as_posix() + + # Clean up files + mar_file_path.unlink(missing_ok=True) + + +@pytest.fixture(scope="module", name="model_name") +def register_model(mar_file_path, model_store, torchserve): + """ + Register the model in torchserve + """ + shutil.copy(mar_file_path, model_store) + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + params = ( + ("model_name", model_name), + ("url", file_name), + ("initial_workers", "1"), + ("synchronous", "true"), + ("batch_size", "1"), + ) + + test_utils.reg_resp = test_utils.register_model_with_params(params) + + yield model_name, torchserve + + test_utils.unregister_model(model_name) + + +@pytest.mark.skipif( + platform.system() != "Linux", reason="Skipping test on non-Linux system" +) +def test_tp_inference(model_name): + """ + Full circle test with torchserve + """ + + model_name, pipe = model_name + + response = requests.post( + url=f"http://localhost:8080/predictions/{model_name}", data=json.dumps(42) + ) + assert response.status_code == 500 + + logs = [] + for _ in range(100): + logs.append(pipe.get()) + if "Auto recovery succeeded, reset recoveryStartTS" in logs[-1]: + break + + assert any("Model initialization ... 
DONE" in l for l in logs) + assert any("Number or consecutive unsuccessful inference 1" in l for l in logs) + assert any("Worker disconnected" in l for l in logs) + assert any("Retry worker" in l for l in logs) + assert any("Auto recovery start timestamp" in l for l in logs) + assert not any("Auto recovery failed again" in l for l in logs) diff --git a/test/pytest/test_parallelism.py b/test/pytest/test_parallelism.py new file mode 100644 index 0000000000..04183ec01f --- /dev/null +++ b/test/pytest/test_parallelism.py @@ -0,0 +1,148 @@ +import json +import platform +import shutil +from argparse import Namespace +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import requests +import test_utils + +CURR_FILE_PATH = Path(__file__).parent +REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent + +MODEL_PY = """ +import torch +import torch.nn as nn + +class Foo(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + torch.distributed.all_reduce(x) + return x +""" + +HANDLER_PY = """ +import os +import torch +from ts.torch_handler.base_handler import BaseHandler + +class FooHandler(BaseHandler): + def initialize(self, ctx): + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group("gloo") + torch.set_default_device("cpu") + super().initialize(ctx) + + def preprocess(self, data): + return torch.as_tensor(int(data[0].get('body').decode('utf-8')), device=self.device) + + def postprocess(self, x): + return [x.item()] +""" + +MODEL_CONFIG_YAML = f""" +#frontend settings +parallelType: "tp" +deviceType: "cpu" + +torchrun: + nproc-per-node: 4 +""" + + +@pytest.fixture(scope="module") +def model_name(): + yield "tp_model" + + +@pytest.fixture(scope="module") +def work_dir(tmp_path_factory, model_name): + return Path(tmp_path_factory.mktemp(model_name)) + + +@pytest.fixture(scope="module", name="mar_file_path") +def create_mar_file(work_dir, model_archiver, model_name): + mar_file_path = work_dir.joinpath(model_name + ".mar") + + model_config_yaml_file = work_dir / "model_config.yaml" + model_config_yaml_file.write_text(MODEL_CONFIG_YAML) + + model_py_file = work_dir / "model.py" + model_py_file.write_text(MODEL_PY) + + handler_py_file = work_dir / "handler.py" + handler_py_file.write_text(HANDLER_PY) + + args = Namespace( + model_name=model_name, + version="1.0", + serialized_file=None, + model_file=model_py_file.as_posix(), + handler=handler_py_file.as_posix(), + extra_files=None, + export_path=work_dir, + requirements_file=None, + runtime="python", + force=False, + archive_format="default", + config_file=model_config_yaml_file.as_posix(), + ) + + mock = MagicMock() + mock.parse_args = MagicMock(return_value=args) + with patch("archiver.ArgParser.export_model_args_parser", return_value=mock): + model_archiver.generate_model_archive() + + assert mar_file_path.exists() + + yield mar_file_path.as_posix() + + # Clean up files + mar_file_path.unlink(missing_ok=True) + + +@pytest.fixture(scope="module", name="model_name") +def register_model(mar_file_path, model_store, torchserve): + """ + Register the model in torchserve + """ + shutil.copy(mar_file_path, model_store) + + file_name = Path(mar_file_path).name + + model_name = Path(file_name).stem + + params = ( + ("model_name", model_name), + ("url", file_name), + ("initial_workers", "1"), + ("synchronous", "true"), + ("batch_size", "1"), + ) + + test_utils.reg_resp = test_utils.register_model_with_params(params) + + yield model_name + + test_utils.unregister_model(model_name) + 
+ +@pytest.mark.skipif( + platform.system() != "Linux", reason="Skipping test on non-Linux system" +) +def test_tp_inference(model_name): + """ + Full circle test with torchserve + """ + + response = requests.post( + url=f"http://localhost:8080/predictions/{model_name}", data=json.dumps(42) + ) + + assert int(response.text) == 4 * 42 + + assert response.status_code == 200 diff --git a/ts_scripts/api_utils.py b/ts_scripts/api_utils.py index 99398ef17c..cdfaccac9a 100755 --- a/ts_scripts/api_utils.py +++ b/ts_scripts/api_utils.py @@ -367,7 +367,8 @@ def trigger_all(): exit_code9 = trigger_https_tests_kfv2() exit_code10 = trigger_explanation_tests() exit_code11 = trigger_workflow_tests() - exit_code12 = trigger_workflow_inference_tests() + # Skipping as this test is flaky + # exit_code12 = trigger_workflow_inference_tests() return ( 1 if any( @@ -384,7 +385,6 @@ def trigger_all(): exit_code9, exit_code10, exit_code11, - exit_code12, ] ) else 0 diff --git a/ts_scripts/spellcheck_conf/wordlist.txt b/ts_scripts/spellcheck_conf/wordlist.txt index f8fe15e126..b4fb8bc4a6 100644 --- a/ts_scripts/spellcheck_conf/wordlist.txt +++ b/ts_scripts/spellcheck_conf/wordlist.txt @@ -162,7 +162,10 @@ CN CORS EventLoopGroup EventLoops +CPUs GPUs +TPU +TPUs JVM MaxDirectMemorySize OU @@ -1118,3 +1121,10 @@ quantized Chatbot LLM bitsandbytes +Datadog +Trn +oneAPI +Naver +FlashAttention +GenAI +prem