Skip to content

Commit a16fe4d

Browse files
feat(trainer): Add support for param unpacking in the training function call
Signed-off-by: Brian Gallagher <briangal@gmail.com>
1 parent 427b35d commit a16fe4d

File tree

6 files changed

+197
-49
lines changed

6 files changed

+197
-49
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ uv-venv:
6868
.PHONY: test-python
6969
test-python: uv-venv
7070
@uv sync
71-
@uv run coverage run --source=kubeflow.trainer.backends.kubernetes.backend,kubeflow.trainer.utils.utils -m pytest ./kubeflow/trainer/backends/kubernetes/backend_test.py
71+
@uv run coverage run --source=kubeflow.trainer.backends.kubernetes.backend,kubeflow.trainer.utils.utils -m pytest ./kubeflow/trainer/backends/kubernetes/backend_test.py ./kubeflow/trainer/utils/utils_test.py
7272
@uv run coverage report -m kubeflow/trainer/backends/kubernetes/backend.py kubeflow/trainer/utils/utils.py
7373
ifeq ($(report),xml)
7474
@uv run coverage xml

kubeflow/trainer/backends/kubernetes/backend_test.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@
2424
import random
2525
import string
2626
import uuid
27-
from dataclasses import asdict, dataclass, field
28-
from typing import Any, Optional, Type
27+
from dataclasses import asdict
28+
from typing import Optional
2929
from unittest.mock import Mock, patch
3030

3131
import pytest
@@ -34,28 +34,17 @@
3434
from kubeflow.trainer.constants import constants
3535
from kubeflow.trainer.types import types
3636
from kubeflow.trainer.utils import utils
37-
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
3837
from kubeflow.trainer.backends.kubernetes.backend import KubernetesBackend
38+
from kubeflow.trainer.backends.kubernetes.types import KubernetesBackendConfig
39+
from kubeflow.trainer.test.common import TestCase
40+
from kubeflow.trainer.test.common import (
41+
SUCCESS,
42+
FAILED,
43+
DEFAULT_NAMESPACE,
44+
TIMEOUT,
45+
RUNTIME,
46+
)
3947

40-
41-
@dataclass
42-
class TestCase:
43-
name: str
44-
expected_status: str
45-
config: dict[str, Any] = field(default_factory=dict)
46-
expected_output: Optional[Any] = None
47-
expected_error: Optional[Type[Exception]] = None
48-
__test__ = False
49-
50-
51-
# --------------------------
52-
# Constants for test scenarios
53-
# --------------------------
54-
TIMEOUT = "timeout"
55-
RUNTIME = "runtime"
56-
SUCCESS = "success"
57-
FAILED = "Failed"
58-
DEFAULT_NAMESPACE = "default"
5948
# In all tests runtime name is equal to the framework name.
6049
TORCH_RUNTIME = "torch"
6150
TORCH_TUNE_RUNTIME = "torchtune"
@@ -238,9 +227,9 @@ def get_custom_trainer(
238227
'\nif ! [ -x "$(command -v pip)" ]; then\n python -m ensurepip '
239228
"|| python -m ensurepip --user || apt-get install python-pip"
240229
"\nfi\n\nPIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet"
241-
f" --no-warn-script-location {pip_command} {packages_command}"
230+
f" --no-warn-script-location {pip_command} {packages_command}"
242231
"\n\nread -r -d '' SCRIPT << EOM\n\nfunc=lambda: "
243-
'print("Hello World"),\n\n<lambda>('
232+
'print("Hello World"),\n\n<lambda>(**'
244233
"{'learning_rate': 0.001, 'batch_size': 32})\n\nEOM\nprintf \"%s\" "
245234
'"$SCRIPT" > "backend_test.py"\ntorchrun "backend_test.py"',
246235
],

kubeflow/trainer/test/common.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Shared test utilities and types for Kubeflow Trainer tests.
2+
3+
from dataclasses import dataclass, field
4+
from typing import Any, Optional, Type
5+
6+
7+
# Status and scenario labels shared by the Kubeflow Trainer test modules.
SUCCESS = "success"
FAILED = "Failed"
DEFAULT_NAMESPACE = "default"
TIMEOUT = "timeout"
RUNTIME = "runtime"


@dataclass
class TestCase:
    """Describe a single parametrized test scenario.

    Attributes are documented inline below. Only ``name`` is required;
    every other field has a default, so a scenario that omits them is
    treated as an expected success with empty inputs and no expectations.
    """

    # The class name starts with "Test", so pytest would otherwise try to
    # collect it; this plain class attribute (not a dataclass field) opts out.
    __test__ = False

    # Human-readable label of the scenario.
    name: str
    # Expected outcome label; defaults to SUCCESS.
    expected_status: str = SUCCESS
    # Scenario inputs; a fresh dict is created for every instance.
    config: dict[str, Any] = field(default_factory=dict)
    # Value the code under test is expected to produce, if any.
    expected_output: Optional[Any] = None
    # Exception class the code under test is expected to raise, if any.
    expected_error: Optional[Type[Exception]] = None

kubeflow/trainer/types/types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@ class CustomTrainer:
2929
3030
Args:
3131
func (`Callable`): The function that encapsulates the entire model training process.
32-
func_args (`Optional[Dict]`): The arguments to pass to the function.
33-
packages_to_install (`Optional[List[str]]`):
32+
func_args (`Optional[dict]`): The arguments to pass to the function.
33+
packages_to_install (`Optional[list[str]]`):
3434
A list of Python packages to install before running the function.
3535
pip_index_urls (`list[str]`): The PyPI URLs from which to install
3636
Python packages. The first URL will be the index-url, and remaining ones
3737
are extra-index-urls.
3838
num_nodes (`Optional[int]`): The number of nodes to use for training.
39-
resources_per_node (`Optional[Dict]`): The computing resources to allocate per node.
40-
env (`Optional[Dict[str, str]]`): The environment variables to set in the training nodes.
39+
resources_per_node (`Optional[dict]`): The computing resources to allocate per node.
40+
env (`Optional[dict[str, str]]`): The environment variables to set in the training nodes.
4141
"""
4242

4343
func: Callable

kubeflow/trainer/utils/utils.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
import inspect
1616
import os
1717
import textwrap
18-
from typing import Any, Callable, Optional
19-
from urllib.parse import urlparse
2018

19+
from typing import Callable, Optional, Any
20+
from urllib.parse import urlparse
2121
from kubeflow.trainer.constants import constants
2222
from kubeflow.trainer.types import types
2323
from kubeflow_trainer_api import models
@@ -268,15 +268,19 @@ def get_script_for_python_packages(
268268
if is_mpi:
269269
options.append("--user")
270270

271-
script_for_python_packages = textwrap.dedent(
271+
header_script = textwrap.dedent(
272272
"""
273273
if ! [ -x "$(command -v pip)" ]; then
274274
python -m ensurepip || python -m ensurepip --user || apt-get install python-pip
275275
fi
276276
277-
PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet \
278-
--no-warn-script-location {} {}
279-
""".format(
277+
"""
278+
)
279+
280+
script_for_python_packages = (
281+
header_script
282+
+ "PIP_DISABLE_PIP_VERSION_CHECK=1 python -m pip install --quiet "
283+
+ "--no-warn-script-location {} {}\n".format(
280284
" ".join(options),
281285
packages_str,
282286
)
@@ -318,12 +322,16 @@ def get_command_using_train_func(
318322
# Wrap function code to execute it from the file. For example:
319323
# TODO (andreyvelich): Find a better way to run users' scripts.
320324
# def train(parameters):
321-
# print('Start Training...')
325+
# print('Start Training...')
322326
# train({'lr': 0.01})
323327
if train_func_parameters is None:
324-
func_code = f"{func_code}\n{train_func.__name__}()\n"
328+
func_call = f"{train_func.__name__}()"
325329
else:
326-
func_code = f"{func_code}\n{train_func.__name__}({train_func_parameters})\n"
330+
# Always unpack kwargs for training function calls.
331+
func_call = f"{train_func.__name__}(**{train_func_parameters})"
332+
333+
# Combine everything into the final code string.
334+
func_code = f"{func_code}\n{func_call}\n"
327335

328336
is_mpi = runtime.trainer.command[0] == "mpirun"
329337
# The default file location for OpenMPI is: /home/mpiuser/<FILE_NAME>.py

kubeflow/trainer/utils/utils_test.py

Lines changed: 139 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from dataclasses import dataclass
16-
from typing import Any, Dict
17-
1815
import pytest
1916

2017
from kubeflow.trainer.utils import utils
2118
from kubeflow.trainer.constants import constants
19+
from kubeflow.trainer.types import types
20+
from kubeflow.trainer.test.common import TestCase, SUCCESS, FAILED
2221

23-
24-
@dataclass
25-
class TestCase:
26-
name: str
27-
config: Dict[str, Any]
28-
expected_output: str
29-
__test__ = False
30-
22+
def _build_runtime() -> types.Runtime:
    """Return a ``Runtime`` named "test-runtime" for the command tests.

    The runtime wraps a custom-trainer ``RuntimeTrainer`` configured for the
    "torch" framework on one "cpu" device, with its command set to
    ``constants.DEFAULT_COMMAND``.
    """
    trainer_settings = {
        "trainer_type": types.TrainerType.CUSTOM_TRAINER,
        "framework": "torch",
        "device": "cpu",
        "device_count": "1",
    }
    trainer = types.RuntimeTrainer(**trainer_settings)
    # The command is not a constructor argument here; it is applied through
    # the trainer's setter after construction.
    trainer.set_command(constants.DEFAULT_COMMAND)
    return types.Runtime(name="test-runtime", trainer=trainer)
3131

3232
@pytest.mark.parametrize(
3333
"test_case",
@@ -124,3 +124,131 @@ def test_get_script_for_python_packages(test_case):
124124
)
125125

126126
assert test_case.expected_output == script
127+
128+
# NOTE: the expected scripts below contain the literal source line of each
# `"func": ...` entry, so those lines must keep their exact text and stay on
# a single line; reformatting them changes the expected output.
@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="with args dict always unpacks kwargs",
            expected_status=SUCCESS,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": {"batch_size": 128, "learning_rate": 0.001, "epochs": 20},
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda: print("Hello World")),\n\n'
                    "<lambda>(**{'batch_size': 128, 'learning_rate': 0.001, 'epochs': 20})\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
        TestCase(
            name="without args calls function with no params",
            expected_status=SUCCESS,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": None,
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda: print("Hello World")),\n\n'
                    "<lambda>()\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
        TestCase(
            name="raises when runtime has no trainer",
            expected_status=FAILED,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": None,
                "runtime": types.Runtime(name="no-trainer", trainer=None),
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="raises when train_func is not callable",
            expected_status=FAILED,
            config={
                "func": "not callable",
                "func_args": None,
                "runtime": _build_runtime(),
            },
            expected_error=ValueError,
        ),
        TestCase(
            name="single dict param also unpacks kwargs",
            expected_status=SUCCESS,
            config={
                "func": (lambda: print("Hello World")),
                "func_args": {"a": 1, "b": 2},
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda: print("Hello World")),\n\n'
                    "<lambda>(**{'a': 1, 'b': 2})\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
        TestCase(
            name="multi-param function uses kwargs-unpacking",
            expected_status=SUCCESS,
            config={
                "func": (lambda **kwargs: "ok"),
                "func_args": {"a": 3, "b": "hi", "c": 0.2},
                "runtime": _build_runtime(),
            },
            expected_output=[
                "bash",
                "-c",
                (
                    "\nread -r -d '' SCRIPT << EOM\n\n"
                    '"func": (lambda **kwargs: "ok"),\n\n'
                    "<lambda>(**{'a': 3, 'b': 'hi', 'c': 0.2})\n\n"
                    "EOM\n"
                    'printf "%s" "$SCRIPT" > "utils_test.py"\n'
                    'python "utils_test.py"'
                ),
            ],
        ),
    ],
)
def test_get_command_using_train_func(test_case: TestCase):
    """Check the command built by ``utils.get_command_using_train_func``.

    Scenarios that set ``expected_error`` must raise exactly that exception
    type; all other scenarios must return ``expected_output``.

    Args:
        test_case: Scenario holding the ``runtime``, ``func`` and ``func_args``
            inputs in ``config`` together with the expected result.
    """
    print("Executing test:", test_case.name)

    def build_command():
        # Single place for the call under test, shared by both branches.
        return utils.get_command_using_train_func(
            runtime=test_case.config["runtime"],
            train_func=test_case.config.get("func"),
            train_func_parameters=test_case.config.get("func_args"),
            pip_index_urls=constants.DEFAULT_PIP_INDEX_URLS,
            packages_to_install=[],
        )

    if test_case.expected_error is not None:
        assert test_case.expected_status == FAILED
        with pytest.raises(test_case.expected_error) as exc_info:
            build_command()
        # pytest.raises also accepts subclasses; keep the exact-type check.
        assert type(exc_info.value) is test_case.expected_error
    else:
        # Assertions live outside any try/except so that a mismatch is
        # reported with the real diff and unexpected exceptions keep their
        # traceback instead of being swallowed by a broad handler.
        assert test_case.expected_status == SUCCESS
        assert build_command() == test_case.expected_output

    print("test execution complete")

0 commit comments

Comments
 (0)