Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions kubeflow/trainer/backends/kubernetes/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,26 @@ def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements:

def get_custom_trainer(
env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None,
pip_index_urls: Optional[list[str]] = constants.DEFAULT_PIP_INDEX_URLS,
packages_to_install: list[str] = ["torch", "numpy"],
) -> models.TrainerV1alpha1Trainer:
"""
Get the custom trainer for the TrainJob.
"""
pip_command = [f"--index-url {pip_index_urls[0]}"]
pip_command.extend([f"--extra-index-url {repo}" for repo in pip_index_urls[1:]])
pip_command = " ".join(pip_command)

packages_command = " ".join(packages_to_install)
return models.TrainerV1alpha1Trainer(
command=[
"bash",
"-c",
'\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip '
"|| python -m ensurepip --user || apt-get install python-pip"
"\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
" --no-warn-script-location --index-url https://pypi.org/simple "
"torch numpy \n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
f" --no-warn-script-location {pip_command} {packages_command}"
"\n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
'print("Hello World"),\n\n<lambda>('
"{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" "
'"$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"',
Expand Down Expand Up @@ -723,14 +729,17 @@ def test_get_runtime_packages(trainer_client, test_case):
func=lambda: print("Hello World"),
func_args={"learning_rate": 0.001, "batch_size": 32},
packages_to_install=["torch", "numpy"],
pip_index_url=constants.DEFAULT_PIP_INDEX_URL,
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
num_nodes=2,
)
},
expected_output=get_train_job(
runtime_name=TORCH_RUNTIME,
train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
train_job_trainer=get_custom_trainer(),
train_job_trainer=get_custom_trainer(
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
packages_to_install=["torch", "numpy"],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit: could be other packages since those are already in the runtime image.

),
),
),
TestCase(
Expand All @@ -741,7 +750,7 @@ def test_get_runtime_packages(trainer_client, test_case):
func=lambda: print("Hello World"),
func_args={"learning_rate": 0.001, "batch_size": 32},
packages_to_install=["torch", "numpy"],
pip_index_url=constants.DEFAULT_PIP_INDEX_URL,
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
num_nodes=2,
env={
"TEST_ENV": "test_value",
Expand All @@ -757,6 +766,8 @@ def test_get_runtime_packages(trainer_client, test_case):
models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"),
models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"),
],
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
packages_to_install=["torch", "numpy"],
),
),
),
Expand Down Expand Up @@ -788,6 +799,7 @@ def test_get_runtime_packages(trainer_client, test_case):
},
expected_error=ValueError,
),

],
)
def test_train(trainer_client, test_case):
Expand Down
5 changes: 3 additions & 2 deletions kubeflow/trainer/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@
NODE,
)

# The default PIP index URL to download Python packages.
DEFAULT_PIP_INDEX_URL = os.getenv("DEFAULT_PIP_INDEX_URL", "https://pypi.org/simple")
# The default PIP index URLs to download Python packages.
# The DEFAULT_PIP_INDEX_URLS environment variable may hold several comma-separated URLs:
# the first one is used as the index-url, and the remaining ones as extra-index-urls.
# Whitespace around entries and empty entries (e.g. from a trailing comma) are dropped so
# they never reach the pip command line; if nothing usable remains, fall back to PyPI.
DEFAULT_PIP_INDEX_URLS = [
    url.strip()
    for url in os.getenv("DEFAULT_PIP_INDEX_URLS", "https://pypi.org/simple").split(",")
    if url.strip()
] or ["https://pypi.org/simple"]

# The exec script to embed training function into container command.
# __ENTRYPOINT__ depends on the MLPolicy, func_code and func_file is substituted in the `train` API.
Expand Down
8 changes: 6 additions & 2 deletions kubeflow/trainer/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class CustomTrainer:
func_args (`Optional[Dict]`): The arguments to pass to the function.
packages_to_install (`Optional[List[str]]`):
A list of Python packages to install before running the function.
pip_index_url (`Optional[str]`): The PyPI URL from which to install Python packages.
pip_index_urls (`list[str]`): The PyPI URLs from which to install
Python packages. The first URL will be the index-url, and remaining ones
are extra-index-urls.
num_nodes (`Optional[int]`): The number of nodes to use for training.
resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes.
Expand All @@ -41,7 +43,9 @@ class CustomTrainer:
func: Callable
func_args: Optional[dict] = None
packages_to_install: Optional[list[str]] = None
pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL
pip_index_urls: list[str] = field(
default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS)
)
num_nodes: Optional[int] = None
resources_per_node: Optional[dict] = None
env: Optional[dict[str, str]] = None
Expand Down
26 changes: 15 additions & 11 deletions kubeflow/trainer/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,28 +255,32 @@ def get_resources_per_node(

def get_script_for_python_packages(
packages_to_install: list[str],
pip_index_url: str,
pip_index_urls: list[str],
is_mpi: bool,
) -> str:
"""
Get init script to install Python packages from the given pip index URL.
Get init script to install Python packages from the given pip index URLs.
"""
# packages_str = " ".join([str(package) for package in packages_to_install])
packages_str = " ".join(packages_to_install)

# first url will be the index-url.
options = [f"--index-url {pip_index_urls[0]}"]
options.extend(f"--extra-index-url {extra_index_url}" for extra_index_url in pip_index_urls[1:])
# For the OpenMPI, the packages must be installed for the mpiuser.
if is_mpi:
options.append("--user")

script_for_python_packages = textwrap.dedent(
"""
if ! [ -x "$(command -v pip)" ]; then
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
fi

PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
--no-warn-script-location --index-url {} {} {}
--no-warn-script-location {} {}
""".format(
pip_index_url,
" ".join(options),
packages_str,
# For the OpenMPI, the packages must be installed for the mpiuser.
"--user" if is_mpi else "",
)
)

Expand All @@ -287,8 +291,8 @@ def get_command_using_train_func(
runtime: types.Runtime,
train_func: Callable,
train_func_parameters: Optional[dict[str, Any]],
pip_index_url: str,
packages_to_install: Optional[list[str]] = None,
pip_index_urls: list[str],
packages_to_install: Optional[list[str]],
) -> list[str]:
"""
Get the Trainer container command from the given training function and parameters.
Expand Down Expand Up @@ -333,7 +337,7 @@ def get_command_using_train_func(
if packages_to_install:
install_packages = get_script_for_python_packages(
packages_to_install,
pip_index_url,
pip_index_urls,
is_mpi,
)

Expand Down Expand Up @@ -374,7 +378,7 @@ def get_trainer_crd_from_custom_trainer(
runtime,
trainer.func,
trainer.func_args,
trainer.pip_index_url,
trainer.pip_index_urls,
trainer.packages_to_install,
)

Expand Down
126 changes: 126 additions & 0 deletions kubeflow/trainer/utils/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright 2025 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict

import pytest

from kubeflow.trainer.utils import utils
from kubeflow.trainer.constants import constants


@dataclass
class TestCase:
    """A single parametrized scenario for the script-generation test.

    Attributes:
        name: Human-readable label of the scenario.
        config: Keyword arguments for the function under test; the test reads the
            ``packages_to_install``, ``pip_index_urls`` and ``is_mpi`` keys.
        expected_output: The exact script text the function is expected to return.
    """

    # Not a dataclass field (no annotation): tells pytest not to collect this
    # class as a test class despite its ``Test`` prefix.
    __test__ = False

    name: str
    config: Dict[str, Any]
    expected_output: str


def _expected_index_options(pip_index_urls):
    """Render the pip index flags expected for ``pip_index_urls``.

    The first URL becomes ``--index-url`` and every following URL becomes an
    ``--extra-index-url``, joined by single spaces.
    """
    flags = [f"--index-url {pip_index_urls[0]}"]
    flags.extend(f"--extra-index-url {url}" for url in pip_index_urls[1:])
    return " ".join(flags)


# NOTE(review): the expected strings below must match the generated script
# character for character, including the whitespace that precedes
# "python -m ensurepip" — confirm against utils.get_script_for_python_packages.
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="multiple pip index URLs",
            config={
                "packages_to_install": ["torch", "numpy", "custom-package"],
                "pip_index_urls": [
                    "https://pypi.org/simple",
                    "https://private.repo.com/simple",
                    "https://internal.company.com/simple",
                ],
                "is_mpi": False,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location --index-url https://pypi.org/simple '
                '--extra-index-url https://private.repo.com/simple '
                '--extra-index-url https://internal.company.com/simple '
                'torch numpy custom-package\n'
            ),
        ),
        TestCase(
            name="single pip index URL (backward compatibility)",
            config={
                "packages_to_install": ["torch", "numpy", "custom-package"],
                "pip_index_urls": ["https://pypi.org/simple"],
                "is_mpi": False,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location --index-url https://pypi.org/simple '
                'torch numpy custom-package\n'
            ),
        ),
        TestCase(
            name="multiple pip index URLs with MPI",
            config={
                "packages_to_install": ["torch", "numpy", "custom-package"],
                "pip_index_urls": [
                    "https://pypi.org/simple",
                    "https://private.repo.com/simple",
                    "https://internal.company.com/simple",
                ],
                "is_mpi": True,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location --index-url https://pypi.org/simple '
                '--extra-index-url https://private.repo.com/simple '
                '--extra-index-url https://internal.company.com/simple '
                '--user torch numpy custom-package\n'
            ),
        ),
        TestCase(
            name="default pip index URLs",
            config={
                "packages_to_install": ["torch", "numpy"],
                "pip_index_urls": constants.DEFAULT_PIP_INDEX_URLS,
                "is_mpi": False,
            },
            # The defaults come from the environment and may contain more than
            # one URL, so the expectation is derived from the whole list rather
            # than from its first entry only.
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location '
                f'{_expected_index_options(constants.DEFAULT_PIP_INDEX_URLS)} '
                'torch numpy\n'
            ),
        ),
    ],
    # Use the scenario name as the pytest id so a failure identifies its case.
    ids=lambda test_case: test_case.name,
)
def test_get_script_for_python_packages(test_case):
    """Test get_script_for_python_packages with various configurations."""

    script = utils.get_script_for_python_packages(
        packages_to_install=test_case.config["packages_to_install"],
        pip_index_urls=test_case.config["pip_index_urls"],
        is_mpi=test_case.config["is_mpi"],
    )

    assert test_case.expected_output == script