# Patch summary (commentary only; it sits ahead of the first "diff --git" header).
#
# Purpose: store a container image on the Runtime object so that resolve_image can
# return it instead of always looking the image up in DEFAULT_FRAMEWORK_IMAGES.
#
#   README.md               ContainerBackend bullet rewritten with no visible text change.
#                           NOTE(review): presumably trailing whitespace was removed - confirm.
#   runtime_loader.py       _create_default_runtimes and _parse_runtime_yaml pass image= to
#                           Runtime. _parse_runtime_yaml takes the image of the container
#                           named 'node' when it has one, otherwise the first container that
#                           has an image, and raises ValueError when the container list is
#                           empty or no container has an image.
#   runtime_loader_test.py  Asserts the default torch image and adds parametrized tests for
#                           _parse_runtime_yaml image extraction and for resolve_image.
#   utils.py                resolve_image returns runtime.image when it is truthy, otherwise
#                           the DEFAULT_FRAMEWORK_IMAGES entry for the framework.
#   types/types.py          Runtime gains the field image: Optional[str] = None.
#
# NOTE(review): line breaks, indentation and blank context lines below are laid out to agree
# with each hunk's line counts; verify against the target tree before applying.
diff --git a/README.md b/README.md
index 7e94d653..0088b427 100644
--- a/README.md
+++ b/README.md
@@ -78,7 +78,7 @@ Kubeflow Trainer client supports local development without needing a Kubernetes
 ### Available Backends

 - **KubernetesBackend** (default) - Production training on Kubernetes
-- **ContainerBackend** - Local development with Docker/Podman isolation
+- **ContainerBackend** - Local development with Docker/Podman isolation
 - **LocalProcessBackend** - Quick prototyping with Python subprocesses

 **Quick Start:**
diff --git a/kubeflow/trainer/backends/container/runtime_loader.py b/kubeflow/trainer/backends/container/runtime_loader.py
index 3e614400..b20d7ba4 100644
--- a/kubeflow/trainer/backends/container/runtime_loader.py
+++ b/kubeflow/trainer/backends/container/runtime_loader.py
@@ -329,6 +329,7 @@ def _create_default_runtimes() -> list[base_types.Runtime]:
                 num_nodes=1,
             ),
             pretrained_model=None,
+            image=image,
         )
         default_runtimes.append(runtime)
         logger.debug(f"Created default runtime: {runtime.name} with image {image}")
@@ -385,8 +386,27 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
         )
     node_spec = node_jobs[0].get("template", {}).get("spec", {}).get("template", {}).get("spec", {})
     containers = node_spec.get("containers", [])
-    if not containers or not containers[0].get("image"):
-        raise ValueError(f"Runtime {name} from {source} 'node' must specify containers[0].image")
+    if not containers:
+        raise ValueError(f"Runtime {name} from {source} 'node' must specify at least one container")
+
+    # Prefer the image of the container named 'node', provided that container declares one
+    image = None
+    for container in containers:
+        if container.get("name") == "node" and container.get("image"):
+            image = container.get("image")
+            break
+
+    # Otherwise fall back to the first container (of any name) that declares an image
+    if not image:
+        for container in containers:
+            if container.get("image"):
+                image = container.get("image")
+                break
+
+    if not image:
+        raise ValueError(
+            f"Runtime {name} from {source} 'node' must specify an image in at least one container"
+        )

     return base_types.Runtime(
         name=name,
@@ -396,6 +416,7 @@ def _parse_runtime_yaml(data: dict[str, Any], source: str = "unknown") -> base_t
             num_nodes=num_nodes,
         ),
         pretrained_model=None,
+        image=image,
     )


diff --git a/kubeflow/trainer/backends/container/runtime_loader_test.py b/kubeflow/trainer/backends/container/runtime_loader_test.py
index ea0ec2cf..a62fef92 100644
--- a/kubeflow/trainer/backends/container/runtime_loader_test.py
+++ b/kubeflow/trainer/backends/container/runtime_loader_test.py
@@ -357,6 +357,8 @@ def test_create_default_runtimes():
     assert torch_runtimes[0].name == "torch-distributed"
     assert torch_runtimes[0].trainer.trainer_type == base_types.TrainerType.CUSTOM_TRAINER
     assert torch_runtimes[0].trainer.num_nodes == 1
+    # Verify default image is set
+    assert torch_runtimes[0].image == constants.DEFAULT_FRAMEWORK_IMAGES["torch"]
     print("test execution complete")


@@ -524,3 +526,169 @@ def test_fetch_runtime_from_github(test_case):
     except Exception as e:
         assert type(e) is test_case.expected_error
     print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="parse runtime yaml with custom image",
+            expected_status=SUCCESS,
+            config={
+                "custom_image": "quay.io/custom/pytorch-arm:v1.0",
+                "runtime_name": "torch-arm",
+                "framework": "torch",
+                "num_nodes": 2,
+            },
+        ),
+        TestCase(
+            name="parse runtime yaml with different custom image",
+            expected_status=SUCCESS,
+            config={
+                "custom_image": "my-registry.io/pytorch:gpu-arm64",
+                "runtime_name": "torch-gpu-arm",
+                "framework": "torch",
+                "num_nodes": 4,
+            },
+        ),
+        TestCase(
+            name="parse runtime yaml prefers container named node",
+            expected_status=SUCCESS,
+            config={
+                "custom_image": "correct-node-image:v1.0",
+                "runtime_name": "multi-container-runtime",
+                "framework": "torch",
+                "num_nodes": 1,
+                "multiple_containers": True,
+            },
+        ),
+    ],
+)
+def test_parse_runtime_yaml_extracts_image(test_case):
+    """
+    Test that _parse_runtime_yaml correctly extracts and stores the container image.
+    This prevents regression of bugs where custom images are ignored.
+    """
+    print("Executing test:", test_case.name)
+    try:
+        # Create container list based on test case
+        if test_case.config.get("multiple_containers"):
+            # Test case with multiple containers - should prefer 'node' container
+            containers = [
+                {
+                    "name": "sidecar",
+                    "image": "wrong-sidecar-image:v1.0",
+                },
+                {
+                    "name": "node",
+                    "image": test_case.config["custom_image"],
+                },
+            ]
+        else:
+            # Single container test case
+            containers = [
+                {
+                    "name": "trainer",
+                    "image": test_case.config["custom_image"],
+                }
+            ]
+
+        # Create runtime YAML with custom image
+        runtime_yaml = {
+            "kind": "ClusterTrainingRuntime",
+            "metadata": {
+                "name": test_case.config["runtime_name"],
+                "labels": {"trainer.kubeflow.org/framework": test_case.config["framework"]},
+            },
+            "spec": {
+                "mlPolicy": {"numNodes": test_case.config["num_nodes"]},
+                "template": {
+                    "spec": {
+                        "replicatedJobs": [
+                            {
+                                "name": "node",
+                                "template": {
+                                    "spec": {"template": {"spec": {"containers": containers}}}
+                                },
+                            }
+                        ]
+                    }
+                },
+            },
+        }
+
+        runtime = runtime_loader._parse_runtime_yaml(runtime_yaml, "test")
+
+        # Verify image is extracted and stored
+        assert runtime.image == test_case.config["custom_image"]
+        assert runtime.name == test_case.config["runtime_name"]
+        assert runtime.trainer.framework == test_case.config["framework"]
+        assert runtime.trainer.num_nodes == test_case.config["num_nodes"]
+
+        assert test_case.expected_status == SUCCESS
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
+
+
+@pytest.mark.parametrize(
+    "test_case",
+    [
+        TestCase(
+            name="resolve image uses custom image",
+            expected_status=SUCCESS,
+            config={
+                "custom_image": "my-registry.io/pytorch-custom:arm64",
+                "framework": "torch",
+                "expect_custom": True,
+            },
+        ),
+        TestCase(
+            name="resolve image falls back to default when no custom image",
+            expected_status=SUCCESS,
+            config={
+                "custom_image": None,
+                "framework": "torch",
+                "expect_custom": False,
+            },
+        ),
+    ],
+)
+def test_resolve_image_uses_custom_image(test_case):
+    """
+    Test that resolve_image prioritizes runtime.image over default framework images.
+    This ensures custom images from ClusterTrainingRuntimes are actually used.
+    """
+    print("Executing test:", test_case.name)
+    try:
+        from kubeflow.trainer.backends.container import utils
+
+        # Create runtime with or without custom image
+        runtime = base_types.Runtime(
+            name="test-runtime",
+            trainer=base_types.RuntimeTrainer(
+                trainer_type=base_types.TrainerType.CUSTOM_TRAINER,
+                framework=test_case.config["framework"],
+                num_nodes=1,
+            ),
+            image=test_case.config["custom_image"],
+        )
+
+        resolved_image = utils.resolve_image(runtime)
+
+        if test_case.config["expect_custom"]:
+            # Should use custom image
+            assert resolved_image == test_case.config["custom_image"]
+        else:
+            # Should fall back to default
+            assert (
+                resolved_image == constants.DEFAULT_FRAMEWORK_IMAGES[test_case.config["framework"]]
+            )
+            assert "pytorch/pytorch" in resolved_image
+
+        assert test_case.expected_status == SUCCESS
+
+    except Exception as e:
+        assert type(e) is test_case.expected_error
+    print("test execution complete")
diff --git a/kubeflow/trainer/backends/container/utils.py b/kubeflow/trainer/backends/container/utils.py
index 8642f865..684d5559 100644
--- a/kubeflow/trainer/backends/container/utils.py
+++ b/kubeflow/trainer/backends/container/utils.py
@@ -152,7 +152,11 @@ def aggregate_status_from_containers(container_statuses: list[str]) -> str:

 def resolve_image(runtime: types.Runtime) -> str:
     """
-    Resolve the container image for a runtime from DEFAULT_FRAMEWORK_IMAGES.
+    Resolve the container image for a runtime.
+
+    Priority:
+    1. Use runtime.image if it is set (e.g. parsed from a ClusterTrainingRuntime)
+    2. Fall back to DEFAULT_FRAMEWORK_IMAGES based on framework

     Args:
         runtime: Runtime object.
@@ -163,6 +167,11 @@ def resolve_image(runtime: types.Runtime) -> str:
     Raises:
         ValueError: If no image is found for the runtime's framework.
     """
+    # Use image from runtime if specified
+    if runtime.image:
+        return runtime.image
+
+    # Fall back to default framework images
     framework = runtime.trainer.framework
     if framework in constants.DEFAULT_FRAMEWORK_IMAGES:
         return constants.DEFAULT_FRAMEWORK_IMAGES[framework]
diff --git a/kubeflow/trainer/types/types.py b/kubeflow/trainer/types/types.py
index cb666985..a2b5f5fb 100644
--- a/kubeflow/trainer/types/types.py
+++ b/kubeflow/trainer/types/types.py
@@ -250,6 +250,7 @@ class Runtime:
     name: str
     trainer: RuntimeTrainer
     pretrained_model: Optional[str] = None
+    image: Optional[str] = None  # Container image for this runtime, if one is known.


 # Representation for the TrainJob steps.