diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index c470ff3b62a..f637364d94a 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -686,6 +686,17 @@ def _create_pipeline_workflow(self, args, pipeline, op_transformers=None, pipeli if exit_handler: workflow['spec']['onExit'] = exit_handler.name + + if pipeline_conf.image_pull_policy != None: + if pipeline_conf.image_pull_policy in ["Always", "Never", "IfNotPresent"]: + for template in workflow["spec"]["templates"]: + container = template.get('container', None) + if container and "imagePullPolicy" not in container: + container["imagePullPolicy"] = pipeline_conf.image_pull_policy + else: + raise ValueError( + 'Invalid imagePullPolicy. Must be one of `Always`, `Never`, `IfNotPresent`.' + ) return workflow def _validate_exit_handler(self, pipeline): diff --git a/sdk/python/kfp/dsl/_pipeline.py b/sdk/python/kfp/dsl/_pipeline.py index 13ff94e9bdd..e6328a5f3aa 100644 --- a/sdk/python/kfp/dsl/_pipeline.py +++ b/sdk/python/kfp/dsl/_pipeline.py @@ -61,6 +61,7 @@ def __init__(self): self.timeout = 0 self.ttl_seconds_after_finished = -1 self.op_transformers = [] + self.image_pull_policy = None def set_image_pull_secrets(self, image_pull_secrets): """Configures the pipeline level imagepullsecret @@ -91,6 +92,16 @@ def set_ttl_seconds_after_finished(self, seconds: int): self.ttl_seconds_after_finished = seconds return self + def set_image_pull_policy(self, policy: str): + """Configures the default image pull policy + + Args: + policy: the pull policy, has to be one of: Always, Never, IfNotPresent. + For more info: https://github.com/kubernetes-client/python/blob/10a7f95435c0b94a6d949ba98375f8cc85a70e5a/kubernetes/docs/V1Container.md + """ + self.image_pull_policy = policy + return self + def add_op_transformer(self, transformer): """Configures the op_transformers which will be applied to all ops in the pipeline. @@ -218,5 +229,3 @@ def _set_metadata(self, metadata): metadata (ComponentMeta): component metadata ''' self._metadata = metadata - - diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index 7f4b8a8a114..6d3671fe341 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -678,6 +678,76 @@ def some_pipeline(): container = template.get('container', None) if container: self.assertEqual(template['retryStrategy']['limit'], 5) + + def test_image_pull_policy(self): + def some_op(): + return dsl.ContainerOp( + name='sleep', + image='busybox', + command=['sleep 1'], + ) + + @dsl.pipeline(name='some_pipeline') + def some_pipeline(): + task1 = some_op() + task2 = some_op() + task3 = some_op() + + dsl.get_pipeline_conf().set_image_pull_policy(policy="Always") + workflow_dict = compiler.Compiler()._compile(some_pipeline) + for template in workflow_dict['spec']['templates']: + container = template.get('container', None) + if container: + self.assertEqual(template['container']['imagePullPolicy'], "Always") + + + def test_image_pull_policy_step_spec(self): + def some_op(): + return dsl.ContainerOp( + name='sleep', + image='busybox', + command=['sleep 1'], + ) + + def some_other_op(): + return dsl.ContainerOp( + name='other', + image='busybox', + command=['sleep 1'], + ) + + @dsl.pipeline(name='some_pipeline') + def some_pipeline(): + task1 = some_op() + task2 = some_op() + task3 = some_other_op().set_image_pull_policy("IfNotPresent") + + dsl.get_pipeline_conf().set_image_pull_policy(policy="Always") + workflow_dict = compiler.Compiler()._compile(some_pipeline) + for template in workflow_dict['spec']['templates']: + container = template.get('container', None) + if container: + if template['name' ] == "other": + self.assertEqual(template['container']['imagePullPolicy'], "IfNotPresent") + elif template['name' ] == "sleep": + self.assertEqual(template['container']['imagePullPolicy'], "Always") + + def test_image_pull_policy_invalid_setting(self): + def some_op(): + return dsl.ContainerOp( + name='sleep', + image='busybox', + command=['sleep 1'], + ) + + with self.assertRaises(ValueError): + @dsl.pipeline(name='some_pipeline') + def some_pipeline(): + task1 = some_op() + task2 = some_op() + dsl.get_pipeline_conf().set_image_pull_policy(policy="Alwayss") + + workflow_dict = compiler.Compiler()._compile(some_pipeline) def test_container_op_output_error_when_no_or_multiple_outputs(self):