Skip to content

Commit

Permalink
feat: update setup.py for including Ray v2.33, restrict RoV predictio…
Browse files Browse the repository at this point in the history
…n to 2.9.3 for now

PiperOrigin-RevId: 673086035
  • Loading branch information
speedstorm1 authored and copybara-github committed Sep 10, 2024
1 parent 6cc5fa2 commit 71c6f3c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 4 deletions.
14 changes: 14 additions & 0 deletions google/cloud/aiplatform/vertex_ray/predict/sklearn/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,15 @@ def register_sklearn(
Raises:
ValueError: Invalid Argument.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to upload Sklearn"
" model to Vertex Model Registry yet. Please use Ray 2.9.3."
)

artifact_uri = artifact_uri or initializer.global_config.staging_bucket
predict_utils.validate_artifact_uri(artifact_uri)
display_model_name = (
Expand Down Expand Up @@ -122,11 +130,17 @@ def _get_estimator_from(
ValueError: Invalid Argument.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""

ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to convert a Sklearn"
" checkpoint to sklearn estimator on Vertex yet. Please use Ray 2.9.3."
)

try:
return checkpoint.get_model()
Expand Down
6 changes: 6 additions & 0 deletions google/cloud/aiplatform/vertex_ray/predict/torch/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ def get_pytorch_model_from(
ModuleNotFoundError: PyTorch isn't installed.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray on Vertex does not support Ray version {ray_version} to"
" convert PyTorch model artifacts yet. Please use Ray 2.9.3."
)

try:
return checkpoint.get_model()
Expand Down
14 changes: 14 additions & 0 deletions google/cloud/aiplatform/vertex_ray/predict/xgboost/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,15 @@ def register_xgboost(
Raises:
ValueError: Invalid Argument.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to upload XGBoost"
" model to Vertex Model Registry yet. Please use Ray 2.9.3."
)

artifact_uri = artifact_uri or initializer.global_config.staging_bucket
predict_utils.validate_artifact_uri(artifact_uri)
display_model_name = (
Expand Down Expand Up @@ -136,10 +144,16 @@ def _get_xgboost_model_from(
ModuleNotFoundError: XGBoost isn't installed.
RuntimeError: Model not found.
RuntimeError: Ray version 2.4 is not supported.
RuntimeError: Only Ray version 2.9.3 is supported.
"""
ray_version = ray.__version__
if ray_version == "2.4.0":
raise RuntimeError(_V2_4_WARNING_MESSAGE)
if ray_version != "2.9.3":
raise RuntimeError(
f"Ray version {ray_version} is not supported to convert a XGBoost"
" checkpoint to XGBoost model on Vertex yet. Please use Ray 2.9.3."
)

try:
# This works for Ray v2.5
Expand Down
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,16 +101,19 @@
preview_extra_require = []

ray_extra_require = [
# Cluster only supports 2.9.3. Keep 2.4.0 for our testing environment.
# Cluster only supports 2.9.3 and 2.33.0. Keep 2.4.0 for our testing environment.
# Note that testing is submiting a job in a cluster with Ray 2.9.3 remotely.
(
"ray[default] >= 2.4, <= 2.9.3,!= 2.5.*,!= 2.6.*,!= 2.7.*,!="
" 2.8.*,!=2.9.0,!=2.9.1,!=2.9.2; python_version<'3.11'"
"ray[default] >= 2.4, <= 2.33.0,!= 2.5.*,!= 2.6.*,!= 2.7.*,!="
" 2.8.*,!=2.9.0,!=2.9.1,!=2.9.2, !=2.10.*, !=2.11.*, !=2.12.*, !=2.13.*, !="
" 2.14.*, !=2.15.*, !=2.16.*, !=2.17.*, !=2.18.*, !=2.19.*, !=2.20.*, !="
" 2.21.*, !=2.22.*, !=2.23.*, !=2.24.*, !=2.25.*, !=2.26.*, !=2.27.*, !="
" 2.28.*, !=2.29.*, !=2.30.*, !=2.31.*, !=2.32.*; python_version<'3.11'"
),
# To avoid ImportError: cannot import name 'packaging' from 'pkg)resources'
"setuptools < 70.0.0",
# Ray Data v2.4 in Python 3.11 is broken, but got fixed in Ray v2.5.
"ray[default] >= 2.5, <= 2.9.3; python_version=='3.11'",
"ray[default] >= 2.5, <= 2.33.0; python_version=='3.11'",
"google-cloud-bigquery-storage",
"google-cloud-bigquery",
"pandas >= 1.0.0, < 2.2.0",
Expand Down

0 comments on commit 71c6f3c

Please sign in to comment.