Skip to content

Commit 427b35d

Browse files
feat: Support multiple pip index URLs in CustomTrainer (#79)
* feat: Support multiple pip index URLs in CustomTrainer - Add pip_index_urls parameter to CustomTrainer (list of URLs) - Update constants to support multiple URLs with environment variable - Modify utils to handle multiple index URLs in pip installation - Update tests to use new parameter - Fix mutable default issue with dataclasses.field(default_factory=...) Closes #72 Signed-off-by: wassimbensalem <bswassim@gmail.com> * test: add comprehensive test cases for multiple pip index URLs Signed-off-by: wassimbensalem <bswassim@gmail.com> * refactor: simplify constants for pip index URLs Signed-off-by: wassimbensalem <bswassim@gmail.com> * fix: remove Optional from pip_index_urls docstring to match implementation Signed-off-by: wassimbensalem <bswassim@gmail.com> * test: adjust unit test for get_custom_trainer Signed-off-by: wassimbensalem <bswassim@gmail.com> * fix: add missing Dict import in utils.py Signed-off-by: wassimbensalem <bswassim@gmail.com> * refactor: move test cases to utils_test.py and convert to parametrized format Signed-off-by: wassimbensalem <bswassim@gmail.com> * refactor: update utils test to use exact string matching instead of contains/not_contains checks - Changed TestCase expected_output from Dict[str, Any] to str - Replaced contains/not_contains logic with direct string equality - Updated all test cases to specify complete expected script output - Tests now assert the entire script returned by get_script_for_python_packages Signed-off-by: wassimbensalem <bswassim@gmail.com> * feat: add multiple pip index URLs support with comprehensive tests - Add support for multiple pip index URLs in CustomTrainer - Update constants to handle comma-separated environment variables - Add integration test for multiple pip index URLs flow - Update utils to generate correct pip install commands with --index-url and --extra-index-url - Ensure backward compatibility with single URL usage Signed-off-by: wassimbensalem <bswassim@gmail.com> * refactor: remove 
redundant multiple pip index URLs test from backend - Remove redundant test case since comprehensive tests already exist in utils test - Keep integration tests focused on core functionality rather than implementation details - Maintain test coverage while reducing duplication Signed-off-by: wassimbensalem <bswassim@gmail.com> * Revert "refactor: remove redundant multiple pip index URLs test from backend" This reverts commit 0ec003d. Signed-off-by: wassimbensalem <bswassim@gmail.com> * refactor: remove duplicated test case from the backend. Signed-off-by: wassimbensalem <bswassim@gmail.com> --------- Signed-off-by: wassimbensalem <bswassim@gmail.com>
1 parent 0f7a988 commit 427b35d

File tree

5 files changed

+167
-20
lines changed

5 files changed

+167
-20
lines changed

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,20 +220,26 @@ def get_resource_requirements() -> models.IoK8sApiCoreV1ResourceRequirements:
220220

221221
def get_custom_trainer(
222222
env: Optional[list[models.IoK8sApiCoreV1EnvVar]] = None,
223+
pip_index_urls: Optional[list[str]] = constants.DEFAULT_PIP_INDEX_URLS,
224+
packages_to_install: list[str] = ["torch", "numpy"],
223225
) -> models.TrainerV1alpha1Trainer:
224226
"""
225227
Get the custom trainer for the TrainJob.
226228
"""
229+
pip_command = [f"--index-url {pip_index_urls[0]}"]
230+
pip_command.extend([f"--extra-index-url {repo}" for repo in pip_index_urls[1:]])
231+
pip_command = " ".join(pip_command)
227232

233+
packages_command = " ".join(packages_to_install)
228234
return models.TrainerV1alpha1Trainer(
229235
command=[
230236
"bash",
231237
"-c",
232238
'\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip '
233239
"|| python -m ensurepip --user || apt-get install python-pip"
234240
"\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
235-
" --no-warn-script-location --index-url https://pypi.org/simple "
236-
"torch numpy \n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
241+
f" --no-warn-script-location {pip_command} {packages_command}"
242+
"\n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
237243
'print("Hello World"),\n\n<lambda>('
238244
"{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" "
239245
'"$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"',
@@ -723,14 +729,17 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
723729
func=lambda: print("Hello World"),
724730
func_args={"learning_rate": 0.001, "batch_size": 32},
725731
packages_to_install=["torch", "numpy"],
726-
pip_index_url=constants.DEFAULT_PIP_INDEX_URL,
732+
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
727733
num_nodes=2,
728734
)
729735
},
730736
expected_output=get_train_job(
731737
runtime_name=TORCH_RUNTIME,
732738
train_job_name=TRAIN_JOB_WITH_CUSTOM_TRAINER,
733-
train_job_trainer=get_custom_trainer(),
739+
train_job_trainer=get_custom_trainer(
740+
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
741+
packages_to_install=["torch", "numpy"],
742+
),
734743
),
735744
),
736745
TestCase(
@@ -741,7 +750,7 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
741750
func=lambda: print("Hello World"),
742751
func_args={"learning_rate": 0.001, "batch_size": 32},
743752
packages_to_install=["torch", "numpy"],
744-
pip_index_url=constants.DEFAULT_PIP_INDEX_URL,
753+
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
745754
num_nodes=2,
746755
env={
747756
"TEST_ENV": "test_value",
@@ -757,6 +766,8 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
757766
models.IoK8sApiCoreV1EnvVar(name="TEST_ENV", value="test_value"),
758767
models.IoK8sApiCoreV1EnvVar(name="ANOTHER_ENV", value="another_value"),
759768
],
769+
pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
770+
packages_to_install=["torch", "numpy"],
760771
),
761772
),
762773
),
@@ -788,6 +799,7 @@ def test_get_runtime_packages(kubernetes_backend, test_case):
788799
},
789800
expected_error=ValueError,
790801
),
802+
791803
],
792804
)
793805
def test_train(kubernetes_backend, test_case):

kubeflow/trainer/constants/constants.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,9 @@
125125
NODE,
126126
)
127127

128-
# The default PIP index URL to download Python packages.
129-
DEFAULT_PIP_INDEX_URL = os.getenv("DEFAULT_PIP_INDEX_URL", "https://pypi.org/simple")
128+
# Comma-separated pip index URLs, read from the DEFAULT_PIP_INDEX_URLS environment variable (defaults to https://pypi.org/simple).
129+
# The first URL will be the index-url, and remaining ones are extra-index-urls.
130+
# Parse the comma-separated DEFAULT_PIP_INDEX_URLS environment variable into a list.
# Whitespace around each entry is stripped and empty entries are dropped, so values
# such as "https://a/simple, https://b/simple" or a trailing comma do not produce
# malformed URLs. If nothing usable remains (unset, empty, or only separators),
# fall back to the public PyPI index so the list always has a first element.
DEFAULT_PIP_INDEX_URLS = [
    url.strip()
    for url in os.getenv("DEFAULT_PIP_INDEX_URLS", "").split(",")
    if url.strip()
] or ["https://pypi.org/simple"]
130131

131132
# The exec script to embed training function into container command.
132133
# __ENTRYPOINT__ depends on the MLPolicy, func_code and func_file is substituted in the `train` API.

kubeflow/trainer/types/types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ class CustomTrainer:
3232
func_args (`Optional[Dict]`): The arguments to pass to the function.
3333
packages_to_install (`Optional[List[str]]`):
3434
A list of Python packages to install before running the function.
35-
pip_index_url (`Optional[str]`): The PyPI URL from which to install Python packages.
35+
pip_index_urls (`list[str]`): The PyPI URLs from which to install
36+
Python packages. The first URL will be the index-url, and remaining ones
37+
are extra-index-urls.
3638
num_nodes (`Optional[int]`): The number of nodes to use for training.
3739
resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
3840
env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes.
@@ -41,7 +43,9 @@ class CustomTrainer:
4143
func: Callable
4244
func_args: Optional[dict] = None
4345
packages_to_install: Optional[list[str]] = None
44-
pip_index_url: str = constants.DEFAULT_PIP_INDEX_URL
46+
pip_index_urls: list[str] = field(
47+
default_factory=lambda: list(constants.DEFAULT_PIP_INDEX_URLS)
48+
)
4549
num_nodes: Optional[int] = None
4650
resources_per_node: Optional[dict] = None
4751
env: Optional[dict[str, str]] = None

kubeflow/trainer/utils/utils.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -253,28 +253,32 @@ def get_resources_per_node(
253253

254254
def get_script_for_python_packages(
255255
packages_to_install: list[str],
256-
pip_index_url: str,
256+
pip_index_urls: list[str],
257257
is_mpi: bool,
258258
) -> str:
259259
"""
260-
Get init script to install Python packages from the given pip index URL.
260+
Get init script to install Python packages from the given pip index URLs.
261261
"""
262-
# packages_str = " ".join([str(package) for package in packages_to_install])
263262
packages_str = " ".join(packages_to_install)
264263

264+
# first url will be the index-url.
265+
options = [f"--index-url {pip_index_urls[0]}"]
266+
options.extend(f"--extra-index-url {extra_index_url}" for extra_index_url in pip_index_urls[1:])
267+
# For the OpenMPI, the packages must be installed for the mpiuser.
268+
if is_mpi:
269+
options.append("--user")
270+
265271
script_for_python_packages = textwrap.dedent(
266272
"""
267273
if ! [ -x "$(command -v pip)" ]; then
268274
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
269275
fi
270276
271277
PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
272-
--no-warn-script-location --index-url {} {} {}
278+
--no-warn-script-location {} {}
273279
""".format(
274-
pip_index_url,
280+
" ".join(options),
275281
packages_str,
276-
# For the OpenMPI, the packages must be installed for the mpiuser.
277-
"--user" if is_mpi else "",
278282
)
279283
)
280284

@@ -285,8 +289,8 @@ def get_command_using_train_func(
285289
runtime: types.Runtime,
286290
train_func: Callable,
287291
train_func_parameters: Optional[dict[str, Any]],
288-
pip_index_url: str,
289-
packages_to_install: Optional[list[str]] = None,
292+
pip_index_urls: list[str],
293+
packages_to_install: Optional[list[str]],
290294
) -> list[str]:
291295
"""
292296
Get the Trainer container command from the given training function and parameters.
@@ -331,7 +335,7 @@ def get_command_using_train_func(
331335
if packages_to_install:
332336
install_packages = get_script_for_python_packages(
333337
packages_to_install,
334-
pip_index_url,
338+
pip_index_urls,
335339
is_mpi,
336340
)
337341

@@ -372,7 +376,7 @@ def get_trainer_crd_from_custom_trainer(
372376
runtime,
373377
trainer.func,
374378
trainer.func_args,
375-
trainer.pip_index_url,
379+
trainer.pip_index_urls,
376380
trainer.packages_to_install,
377381
)
378382

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2025 The Kubeflow Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
from typing import Any, Dict
17+
18+
import pytest
19+
20+
from kubeflow.trainer.utils import utils
21+
from kubeflow.trainer.constants import constants
22+
23+
24+
@dataclass
class TestCase:
    """One parametrized scenario for ``utils.get_script_for_python_packages``."""

    # Human-readable label for the scenario.
    name: str
    # Keyword arguments for the function under test; the test reads the keys
    # "packages_to_install", "pip_index_urls" and "is_mpi".
    config: Dict[str, Any]
    # The exact script text the function is expected to return.
    expected_output: str
    # The class name starts with "Test"; this flag tells pytest not to collect it.
    __test__ = False
30+
31+
32+
def _default_index_options() -> str:
    """Return the pip index flags expected for ``constants.DEFAULT_PIP_INDEX_URLS``.

    The first default URL becomes ``--index-url`` and every further URL becomes
    ``--extra-index-url``, so the expectation stays correct when the defaults
    are overridden with several URLs through the environment.
    """
    primary, *extras = constants.DEFAULT_PIP_INDEX_URLS
    flags = [f"--index-url {primary}"]
    flags.extend(f"--extra-index-url {url}" for url in extras)
    return " ".join(flags)


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="multiple pip index URLs",
            config={
                "packages_to_install": ["torch", "numpy", "custom-package"],
                "pip_index_urls": [
                    "https://pypi.org/simple",
                    "https://private.repo.com/simple",
                    "https://internal.company.com/simple",
                ],
                "is_mpi": False,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location --index-url https://pypi.org/simple '
                '--extra-index-url https://private.repo.com/simple '
                '--extra-index-url https://internal.company.com/simple '
                'torch numpy custom-package\n'
            ),
        ),
        TestCase(
            name="single pip index URL (backward compatibility)",
            config={
                "packages_to_install": ["torch", "numpy", "custom-package"],
                "pip_index_urls": ["https://pypi.org/simple"],
                "is_mpi": False,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location --index-url https://pypi.org/simple '
                'torch numpy custom-package\n'
            ),
        ),
        TestCase(
            name="multiple pip index URLs with MPI",
            config={
                "packages_to_install": ["torch", "numpy", "custom-package"],
                "pip_index_urls": [
                    "https://pypi.org/simple",
                    "https://private.repo.com/simple",
                    "https://internal.company.com/simple",
                ],
                "is_mpi": True,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location --index-url https://pypi.org/simple '
                '--extra-index-url https://private.repo.com/simple '
                '--extra-index-url https://internal.company.com/simple '
                '--user torch numpy custom-package\n'
            ),
        ),
        TestCase(
            name="default pip index URLs",
            config={
                "packages_to_install": ["torch", "numpy"],
                "pip_index_urls": constants.DEFAULT_PIP_INDEX_URLS,
                "is_mpi": False,
            },
            expected_output=(
                '\nif ! [ -x "$(command -v pip)" ]; then\n'
                ' python -m ensurepip || python -m ensurepip --user || '
                'apt-get install python-pip\n'
                'fi\n\n'
                'PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet '
                '--no-warn-script-location '
                # Derived from every default URL, not only the first, so the
                # case does not break when the environment supplies several.
                f'{_default_index_options()} torch numpy\n'
            ),
        ),
    ],
    # Show the scenario name in pytest output instead of an opaque index.
    ids=lambda test_case: test_case.name,
)
def test_get_script_for_python_packages(test_case):
    """Check that the generated install script matches the expected text exactly."""

    script = utils.get_script_for_python_packages(
        packages_to_install=test_case.config["packages_to_install"],
        pip_index_urls=test_case.config["pip_index_urls"],
        is_mpi=test_case.config["is_mpi"],
    )

    assert test_case.expected_output == script

0 commit comments

Comments
 (0)