@@ -357,6 +357,8 @@ def test_create_default_runtimes():
357357 assert torch_runtimes [0 ].name == "torch-distributed"
358358 assert torch_runtimes [0 ].trainer .trainer_type == base_types .TrainerType .CUSTOM_TRAINER
359359 assert torch_runtimes [0 ].trainer .num_nodes == 1
360+ # Verify default image is set
361+ assert torch_runtimes [0 ].image == constants .DEFAULT_FRAMEWORK_IMAGES ["torch" ]
360362 print ("test execution complete" )
361363
362364
@@ -524,3 +526,169 @@ def test_fetch_runtime_from_github(test_case):
524526 except Exception as e :
525527 assert type (e ) is test_case .expected_error
526528 print ("test execution complete" )
529+
530+
531+ @pytest .mark .parametrize (
532+ "test_case" ,
533+ [
534+ TestCase (
535+ name = "parse runtime yaml with custom image" ,
536+ expected_status = SUCCESS ,
537+ config = {
538+ "custom_image" : "quay.io/custom/pytorch-arm:v1.0" ,
539+ "runtime_name" : "torch-arm" ,
540+ "framework" : "torch" ,
541+ "num_nodes" : 2 ,
542+ },
543+ ),
544+ TestCase (
545+ name = "parse runtime yaml with different custom image" ,
546+ expected_status = SUCCESS ,
547+ config = {
548+ "custom_image" : "my-registry.io/pytorch:gpu-arm64" ,
549+ "runtime_name" : "torch-gpu-arm" ,
550+ "framework" : "torch" ,
551+ "num_nodes" : 4 ,
552+ },
553+ ),
554+ TestCase (
555+ name = "parse runtime yaml prefers container named node" ,
556+ expected_status = SUCCESS ,
557+ config = {
558+ "custom_image" : "correct-node-image:v1.0" ,
559+ "runtime_name" : "multi-container-runtime" ,
560+ "framework" : "torch" ,
561+ "num_nodes" : 1 ,
562+ "multiple_containers" : True ,
563+ },
564+ ),
565+ ],
566+ )
567+ def test_parse_runtime_yaml_extracts_image (test_case ):
568+ """
569+ Test that _parse_runtime_yaml correctly extracts and stores the container image.
570+ This prevents regression of bugs where custom images are ignored.
571+ """
572+ print ("Executing test:" , test_case .name )
573+ try :
574+ # Create container list based on test case
575+ if test_case .config .get ("multiple_containers" ):
576+ # Test case with multiple containers - should prefer 'node' container
577+ containers = [
578+ {
579+ "name" : "sidecar" ,
580+ "image" : "wrong-sidecar-image:v1.0" ,
581+ },
582+ {
583+ "name" : "node" ,
584+ "image" : test_case .config ["custom_image" ],
585+ },
586+ ]
587+ else :
588+ # Single container test case
589+ containers = [
590+ {
591+ "name" : "trainer" ,
592+ "image" : test_case .config ["custom_image" ],
593+ }
594+ ]
595+
596+ # Create runtime YAML with custom image
597+ runtime_yaml = {
598+ "kind" : "ClusterTrainingRuntime" ,
599+ "metadata" : {
600+ "name" : test_case .config ["runtime_name" ],
601+ "labels" : {"trainer.kubeflow.org/framework" : test_case .config ["framework" ]},
602+ },
603+ "spec" : {
604+ "mlPolicy" : {"numNodes" : test_case .config ["num_nodes" ]},
605+ "template" : {
606+ "spec" : {
607+ "replicatedJobs" : [
608+ {
609+ "name" : "node" ,
610+ "template" : {
611+ "spec" : {"template" : {"spec" : {"containers" : containers }}}
612+ },
613+ }
614+ ]
615+ }
616+ },
617+ },
618+ }
619+
620+ runtime = runtime_loader ._parse_runtime_yaml (runtime_yaml , "test" )
621+
622+ # Verify image is extracted and stored
623+ assert runtime .image == test_case .config ["custom_image" ]
624+ assert runtime .name == test_case .config ["runtime_name" ]
625+ assert runtime .trainer .framework == test_case .config ["framework" ]
626+ assert runtime .trainer .num_nodes == test_case .config ["num_nodes" ]
627+
628+ assert test_case .expected_status == SUCCESS
629+
630+ except Exception as e :
631+ assert type (e ) is test_case .expected_error
632+ print ("test execution complete" )
633+
634+
635+ @pytest .mark .parametrize (
636+ "test_case" ,
637+ [
638+ TestCase (
639+ name = "resolve image uses custom image" ,
640+ expected_status = SUCCESS ,
641+ config = {
642+ "custom_image" : "my-registry.io/pytorch-custom:arm64" ,
643+ "framework" : "torch" ,
644+ "expect_custom" : True ,
645+ },
646+ ),
647+ TestCase (
648+ name = "resolve image falls back to default when no custom image" ,
649+ expected_status = SUCCESS ,
650+ config = {
651+ "custom_image" : None ,
652+ "framework" : "torch" ,
653+ "expect_custom" : False ,
654+ },
655+ ),
656+ ],
657+ )
658+ def test_resolve_image_uses_custom_image (test_case ):
659+ """
660+ Test that resolve_image prioritizes runtime.image over default framework images.
661+ This ensures custom images from ClusterTrainingRuntimes are actually used.
662+ """
663+ print ("Executing test:" , test_case .name )
664+ try :
665+ from kubeflow .trainer .backends .container import utils
666+
667+ # Create runtime with or without custom image
668+ runtime = base_types .Runtime (
669+ name = "test-runtime" ,
670+ trainer = base_types .RuntimeTrainer (
671+ trainer_type = base_types .TrainerType .CUSTOM_TRAINER ,
672+ framework = test_case .config ["framework" ],
673+ num_nodes = 1 ,
674+ ),
675+ image = test_case .config ["custom_image" ],
676+ )
677+
678+ resolved_image = utils .resolve_image (runtime )
679+
680+ if test_case .config ["expect_custom" ]:
681+ # Should use custom image
682+ assert resolved_image == test_case .config ["custom_image" ]
683+ else :
684+ # Should fall back to default
685+ assert (
686+ resolved_image == constants .DEFAULT_FRAMEWORK_IMAGES [test_case .config ["framework" ]]
687+ )
688+ assert "pytorch/pytorch" in resolved_image
689+
690+ assert test_case .expected_status == SUCCESS
691+
692+ except Exception as e :
693+ assert type (e ) is test_case .expected_error
694+ print ("test execution complete" )
0 commit comments