Skip to content

Commit b87b81b

Browse files
authored
fix: Support custom images in ClusterTrainingRuntime for container backend (#140)
Signed-off-by: Fiona Waters <fiwaters6@gmail.com>
1 parent ed4600d commit b87b81b

File tree

4 files changed

+202
-3
lines changed

4 files changed

+202
-3
lines changed

kubeflow/trainer/backends/container/runtime_loader.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def _create_default_runtimes() -> list[base_types.Runtime]:
329329
num_nodes=1,
330330
),
331331
pretrained_model=None,
332+
image=image,
332333
)
333334
default_runtimes.append(runtime)
334335
logger.debug(f"Created default runtime: {runtime.name} with image {image}")
@@ -385,8 +386,27 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
385386
)
386387
node_spec = node_jobs[0].get("template", {}).get("spec", {}).get("template", {}).get("spec", {})
387388
containers = node_spec.get("containers", [])
388-
if not containers or not containers[0].get("image"):
389-
raise ValueError(f"Runtime {name} from {source} 'node' must specify containers[0].image")
389+
if not containers:
390+
raise ValueError(f"Runtime {name} from {source} 'node' must specify at least one container")
391+
392+
# Extract the container image from the container named 'node', or fallback to first container
393+
image = None
394+
for container in containers:
395+
if container.get("name") == "node" and container.get("image"):
396+
image = container.get("image")
397+
break
398+
399+
# Fallback to first container with an image if no 'node' container found
400+
if not image:
401+
for container in containers:
402+
if container.get("image"):
403+
image = container.get("image")
404+
break
405+
406+
if not image:
407+
raise ValueError(
408+
f"Runtime {name} from {source} 'node' must specify an image in at least one container"
409+
)
390410

391411
return base_types.Runtime(
392412
name=name,
@@ -396,6 +416,7 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
396416
num_nodes=num_nodes,
397417
),
398418
pretrained_model=None,
419+
image=image,
399420
)
400421

401422

kubeflow/trainer/backends/container/runtime_loader_test.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ def test_create_default_runtimes():
357357
assert torch_runtimes[0].name == "torch-distributed"
358358
assert torch_runtimes[0].trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
359359
assert torch_runtimes[0].trainer.num_nodes == 1
360+
# Verify default image is set
361+
assert torch_runtimes[0].image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
360362
print("test execution complete")
361363

362364

@@ -524,3 +526,169 @@ def test_fetch_runtime_from_github(test_case):
524526
except Exception as e:
525527
assert type(e) is test_case.expected_error
526528
print("test execution complete")
529+
530+
531+
@pytest.mark.parametrize(
532+
"test_case",
533+
[
534+
TestCase(
535+
name="parse runtime yaml with custom image",
536+
expected_status=SUCCESS,
537+
config={
538+
"custom_image": "quay.io/custom/pytorch-arm:v1.0",
539+
"runtime_name": "torch-arm",
540+
"framework": "torch",
541+
"num_nodes": 2,
542+
},
543+
),
544+
TestCase(
545+
name="parse runtime yaml with different custom image",
546+
expected_status=SUCCESS,
547+
config={
548+
"custom_image": "my-registry.io/pytorch:gpu-arm64",
549+
"runtime_name": "torch-gpu-arm",
550+
"framework": "torch",
551+
"num_nodes": 4,
552+
},
553+
),
554+
TestCase(
555+
name="parse runtime yaml prefers container named node",
556+
expected_status=SUCCESS,
557+
config={
558+
"custom_image": "correct-node-image:v1.0",
559+
"runtime_name": "multi-container-runtime",
560+
"framework": "torch",
561+
"num_nodes": 1,
562+
"multiple_containers": True,
563+
},
564+
),
565+
],
566+
)
567+
def test_parse_runtime_yaml_extracts_image(test_case):
568+
"""
569+
Test that _parse_runtime_yaml correctly extracts and stores the container image.
570+
This prevents regression of bugs where custom images are ignored.
571+
"""
572+
print("Executing test:", test_case.name)
573+
try:
574+
# Create container list based on test case
575+
if test_case.config.get("multiple_containers"):
576+
# Test case with multiple containers - should prefer 'node' container
577+
containers = [
578+
{
579+
"name": "sidecar",
580+
"image": "wrong-sidecar-image:v1.0",
581+
},
582+
{
583+
"name": "node",
584+
"image": test_case.config["custom_image"],
585+
},
586+
]
587+
else:
588+
# Single container test case
589+
containers = [
590+
{
591+
"name": "trainer",
592+
"image": test_case.config["custom_image"],
593+
}
594+
]
595+
596+
# Create runtime YAML with custom image
597+
runtime_yaml = {
598+
"kind": "ClusterTrainingRuntime",
599+
"metadata": {
600+
"name": test_case.config["runtime_name"],
601+
"labels": {"trainer.kubeflow.org/framework": test_case.config["framework"]},
602+
},
603+
"spec": {
604+
"mlPolicy": {"numNodes": test_case.config["num_nodes"]},
605+
"template": {
606+
"spec": {
607+
"replicatedJobs": [
608+
{
609+
"name": "node",
610+
"template": {
611+
"spec": {"template": {"spec": {"containers": containers}}}
612+
},
613+
}
614+
]
615+
}
616+
},
617+
},
618+
}
619+
620+
runtime = runtime_loader._parse_runtime_yaml(runtime_yaml, "test")
621+
622+
# Verify image is extracted and stored
623+
assert runtime.image == test_case.config["custom_image"]
624+
assert runtime.name == test_case.config["runtime_name"]
625+
assert runtime.trainer.framework == test_case.config["framework"]
626+
assert runtime.trainer.num_nodes == test_case.config["num_nodes"]
627+
628+
assert test_case.expected_status == SUCCESS
629+
630+
except Exception as e:
631+
assert type(e) is test_case.expected_error
632+
print("test execution complete")
633+
634+
635+
@pytest.mark.parametrize(
636+
"test_case",
637+
[
638+
TestCase(
639+
name="resolve image uses custom image",
640+
expected_status=SUCCESS,
641+
config={
642+
"custom_image": "my-registry.io/pytorch-custom:arm64",
643+
"framework": "torch",
644+
"expect_custom": True,
645+
},
646+
),
647+
TestCase(
648+
name="resolve image falls back to default when no custom image",
649+
expected_status=SUCCESS,
650+
config={
651+
"custom_image": None,
652+
"framework": "torch",
653+
"expect_custom": False,
654+
},
655+
),
656+
],
657+
)
658+
def test_resolve_image_uses_custom_image(test_case):
659+
"""
660+
Test that resolve_image prioritizes runtime.image over default framework images.
661+
This ensures custom images from ClusterTrainingRuntimes are actually used.
662+
"""
663+
print("Executing test:", test_case.name)
664+
try:
665+
from kubeflow.trainer.backends.container import utils
666+
667+
# Create runtime with or without custom image
668+
runtime = base_types.Runtime(
669+
name="test-runtime",
670+
trainer=base_types.RuntimeTrainer(
671+
trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
672+
framework=test_case.config["framework"],
673+
num_nodes=1,
674+
),
675+
image=test_case.config["custom_image"],
676+
)
677+
678+
resolved_image = utils.resolve_image(runtime)
679+
680+
if test_case.config["expect_custom"]:
681+
# Should use custom image
682+
assert resolved_image == test_case.config["custom_image"]
683+
else:
684+
# Should fall back to default
685+
assert (
686+
resolved_image == constants.DEFAULT_FRAMEWORK_IMAGES[test_case.config["framework"]]
687+
)
688+
assert "pytorch/pytorch" in resolved_image
689+
690+
assert test_case.expected_status == SUCCESS
691+
692+
except Exception as e:
693+
assert type(e) is test_case.expected_error
694+
print("test execution complete")

kubeflow/trainer/backends/container/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ def aggregate_status_from_containers(container_statuses: list[str]) -> str:
152152

153153
def resolve_image(runtime: types.Runtime) -> str:
154154
"""
155-
Resolve the container image for a runtime from DEFAULT_FRAMEWORK_IMAGES.
155+
Resolve the container image for a runtime.
156+
157+
Priority:
158+
1. Use runtime.image if specified in the ClusterTrainingRuntime
159+
2. Fall back to DEFAULT_FRAMEWORK_IMAGES based on framework
156160
157161
Args:
158162
runtime: Runtime object.
@@ -163,6 +167,11 @@ def resolve_image(runtime: types.Runtime) -> str:
163167
Raises:
164168
ValueError: If no image is found for the runtime's framework.
165169
"""
170+
# Use image from runtime if specified
171+
if runtime.image:
172+
return runtime.image
173+
174+
# Fall back to default framework images
166175
framework = runtime.trainer.framework
167176
if framework in constants.DEFAULT_FRAMEWORK_IMAGES:
168177
return constants.DEFAULT_FRAMEWORK_IMAGES[framework]

kubeflow/trainer/types/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ class Runtime:
250250
name: str
251251
trainer: RuntimeTrainer
252252
pretrained_model: Optional[str] = None
253+
image: Optional[str] = None
253254

254255

255256
# Representation for the TrainJob steps.

0 commit comments

Comments
 (0)