Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Kubeflow Trainer client supports local development without needing a Kubernetes
### Available Backends

- **KubernetesBackend** (default) - Production training on Kubernetes
- **ContainerBackend** - Local development with Docker/Podman isolation
- **ContainerBackend** - Local development with Docker/Podman isolation
- **LocalProcessBackend** - Quick prototyping with Python subprocesses

**Quick Start:**
Expand Down
25 changes: 23 additions & 2 deletions kubeflow/trainer/backends/container/runtime_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def _create_default_runtimes() -> list[base_types.Runtime]:
num_nodes=1,
),
pretrained_model=None,
image=image,
)
default_runtimes.append(runtime)
logger.debug(f"Created default runtime: {runtime.name} with image {image}")
Expand Down Expand Up @@ -385,8 +386,27 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
)
node_spec = node_jobs[0].get("template", {}).get("spec", {}).get("template", {}).get("spec", {})
containers = node_spec.get("containers", [])
if not containers or not containers[0].get("image"):
raise ValueError(f"Runtime {name} from {source} 'node' must specify containers[0].image")
if not containers:
raise ValueError(f"Runtime {name} from {source} 'node' must specify at least one container")

# Extract the container image from the container named 'node', or fallback to first container
image = None
for container in containers:
if container.get("name") == "node" and container.get("image"):
image = container.get("image")
break

# Fallback to first container with an image if no 'node' container found
if not image:
for container in containers:
if container.get("image"):
image = container.get("image")
break

if not image:
raise ValueError(
f"Runtime {name} from {source} 'node' must specify an image in at least one container"
)

return base_types.Runtime(
name=name,
Expand All @@ -396,6 +416,7 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
num_nodes=num_nodes,
),
pretrained_model=None,
image=image,
)


Expand Down
168 changes: 168 additions & 0 deletions kubeflow/trainer/backends/container/runtime_loader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def test_create_default_runtimes():
assert torch_runtimes[0].name == "torch-distributed"
assert torch_runtimes[0].trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
assert torch_runtimes[0].trainer.num_nodes == 1
# Verify default image is set
assert torch_runtimes[0].image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
print("test execution complete")


Expand Down Expand Up @@ -524,3 +526,169 @@ def test_fetch_runtime_from_github(test_case):
except Exception as e:
assert type(e) is test_case.expected_error
print("test execution complete")


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="parse runtime yaml with custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "quay.io/custom/pytorch-arm:v1.0",
                "runtime_name": "torch-arm",
                "framework": "torch",
                "num_nodes": 2,
            },
        ),
        TestCase(
            name="parse runtime yaml with different custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "my-registry.io/pytorch:gpu-arm64",
                "runtime_name": "torch-gpu-arm",
                "framework": "torch",
                "num_nodes": 4,
            },
        ),
        TestCase(
            name="parse runtime yaml prefers container named node",
            expected_status=SUCCESS,
            config={
                "custom_image": "correct-node-image:v1.0",
                "runtime_name": "multi-container-runtime",
                "framework": "torch",
                "num_nodes": 1,
                "multiple_containers": True,
            },
        ),
    ],
)
def test_parse_runtime_yaml_extracts_image(test_case):
    """
    Check that _parse_runtime_yaml stores the container image on the returned Runtime.

    Each case builds a minimal ClusterTrainingRuntime document whose 'node'
    replicated job carries either a single container or a sidecar followed by a
    container named 'node', then verifies the parsed image, name, framework and
    node count against the case configuration.
    """
    print("Executing test:", test_case.name)
    try:
        params = test_case.config

        if params.get("multiple_containers"):
            # The sidecar is listed first on purpose: the parser must pick the
            # container named 'node' rather than the first one it encounters.
            containers = [
                {"name": "sidecar", "image": "wrong-sidecar-image:v1.0"},
                {"name": "node", "image": params["custom_image"]},
            ]
        else:
            # A lone container that is not named 'node' exercises the fallback path.
            containers = [{"name": "trainer", "image": params["custom_image"]}]

        # Assemble the document from the outside in so the nesting stays readable.
        metadata = {
            "name": params["runtime_name"],
            "labels": {"trainer.kubeflow.org/framework": params["framework"]},
        }
        ml_policy = {"numNodes": params["num_nodes"]}
        node_job = {
            "name": "node",
            "template": {"spec": {"template": {"spec": {"containers": containers}}}},
        }
        runtime_yaml = {
            "kind": "ClusterTrainingRuntime",
            "metadata": metadata,
            "spec": {
                "mlPolicy": ml_policy,
                "template": {"spec": {"replicatedJobs": [node_job]}},
            },
        }

        parsed = runtime_loader._parse_runtime_yaml(runtime_yaml, "test")

        # The image must be carried through together with the other parsed fields.
        assert parsed.image == params["custom_image"]
        assert parsed.name == params["runtime_name"]
        assert parsed.trainer.framework == params["framework"]
        assert parsed.trainer.num_nodes == params["num_nodes"]

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")


@pytest.mark.parametrize(
    "test_case",
    [
        TestCase(
            name="resolve image uses custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": "my-registry.io/pytorch-custom:arm64",
                "framework": "torch",
                "expect_custom": True,
            },
        ),
        TestCase(
            name="resolve image falls back to default when no custom image",
            expected_status=SUCCESS,
            config={
                "custom_image": None,
                "framework": "torch",
                "expect_custom": False,
            },
        ),
    ],
)
def test_resolve_image_uses_custom_image(test_case):
    """
    Check that resolve_image prefers runtime.image over the framework default.

    When the runtime carries an image, that exact image must be returned; when
    it does not, the result must be the DEFAULT_FRAMEWORK_IMAGES entry for the
    runtime's framework.
    """
    print("Executing test:", test_case.name)
    try:
        from kubeflow.trainer.backends.container import utils

        params = test_case.config

        # Build the runtime with the image under test (None means "not specified").
        trainer = base_types.RuntimeTrainer(
            trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
            framework=params["framework"],
            num_nodes=1,
        )
        runtime = base_types.Runtime(
            name="test-runtime",
            trainer=trainer,
            image=params["custom_image"],
        )

        resolved_image = utils.resolve_image(runtime)

        if not params["expect_custom"]:
            # No custom image: the framework default must be selected.
            assert resolved_image == constants.DEFAULT_FRAMEWORK_IMAGES[params["framework"]]
            assert "pytorch/pytorch" in resolved_image
        else:
            # Custom image present: it must be returned unchanged.
            assert resolved_image == params["custom_image"]

        assert test_case.expected_status == SUCCESS

    except Exception as e:
        assert type(e) is test_case.expected_error
    print("test execution complete")
11 changes: 10 additions & 1 deletion kubeflow/trainer/backends/container/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def aggregate_status_from_containers(container_statuses: list[str]) -> str:

def resolve_image(runtime: types.Runtime) -> str:
"""
Resolve the container image for a runtime from DEFAULT_FRAMEWORK_IMAGES.
Resolve the container image for a runtime.

Priority:
1. Use runtime.image if specified in the ClusterTrainingRuntime
2. Fall back to DEFAULT_FRAMEWORK_IMAGES based on framework

Args:
runtime: Runtime object.
Expand All @@ -163,6 +167,11 @@ def resolve_image(runtime: types.Runtime) -> str:
Raises:
ValueError: If no image is found for the runtime's framework.
"""
# Use image from runtime if specified
if runtime.image:
return runtime.image

# Fall back to default framework images
framework = runtime.trainer.framework
if framework in constants.DEFAULT_FRAMEWORK_IMAGES:
return constants.DEFAULT_FRAMEWORK_IMAGES[framework]
Expand Down
1 change: 1 addition & 0 deletions kubeflow/trainer/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ class Runtime:
name: str
trainer: RuntimeTrainer
pretrained_model: Optional[str] = None
image: Optional[str] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@astefanutti @Fiona-Waters, shall we move this `image` field under `RuntimeTrainer`?
We also have an initializer in the Runtime, so a top-level image would be ambiguous.
Also, we might need to update the Kubernetes backend to populate this field as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andreyvelich right that makes sense.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can look at creating a follow-up PR to do that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see you already got there :-)



# Representation for the TrainJob steps.
Expand Down
Loading