@@ -647,10 +647,6 @@ def _group_to_dag_spec(
647
647
raise NotImplementedError (
648
648
'dsl.graph_component is not yet supported in KFP v2 compiler.' )
649
649
650
- if isinstance (subgroup , dsl .OpsGroup ) and subgroup .type == 'exit_handler' :
651
- raise NotImplementedError (
652
- 'dsl.ExitHandler is not yet supported in KFP v2 compiler.' )
653
-
654
650
if isinstance (subgroup , dsl .ContainerOp ):
655
651
if hasattr (subgroup , 'importer_spec' ):
656
652
importer_task_name = subgroup .task_spec .task_info .name
@@ -909,8 +905,60 @@ def _create_pipeline_spec(
909
905
op_name_to_parent_groups ,
910
906
)
911
907
908
+ # Exit Handler
909
+ if pipeline .groups [0 ].groups :
910
+ first_group = pipeline .groups [0 ].groups [0 ]
911
+ if first_group .type == 'exit_handler' :
912
+ exit_handler_op = first_group .exit_op
913
+
914
+ # Add exit op task spec
915
+ task_name = exit_handler_op .task_spec .task_info .name
916
+ exit_handler_op .task_spec .dependent_tasks .extend (
917
+ pipeline_spec .root .dag .tasks .keys ())
918
+ exit_handler_op .task_spec .trigger_policy .strategy = (
919
+ pipeline_spec_pb2 .PipelineTaskSpec .TriggerPolicy .TriggerStrategy
920
+ .ALL_UPSTREAM_TASKS_COMPLETED )
921
+ pipeline_spec .root .dag .tasks [task_name ].CopyFrom (
922
+ exit_handler_op .task_spec )
923
+
924
+ # Add exit op component spec if it does not exist.
925
+ component_name = exit_handler_op .task_spec .component_ref .name
926
+ if component_name not in pipeline_spec .components :
927
+ pipeline_spec .components [component_name ].CopyFrom (
928
+ exit_handler_op .component_spec )
929
+
930
+ # Add exit op executor spec if it does not exist.
931
+ executor_label = exit_handler_op .component_spec .executor_label
932
+ if executor_label not in deployment_config .executors :
933
+ deployment_config .executors [executor_label ].container .CopyFrom (
934
+ exit_handler_op .container_spec )
935
+ pipeline_spec .deployment_spec .update (
936
+ json_format .MessageToDict (deployment_config ))
937
+
912
938
return pipeline_spec
913
939
940
+ def _validate_exit_handler (self , pipeline ):
941
+ """Makes sure there is only one global exit handler.
942
+
943
+ This is temporary to be compatible with KFP v1.
944
+ """
945
+
946
+ def _validate_exit_handler_helper (group , exiting_op_names , handler_exists ):
947
+ if group .type == 'exit_handler' :
948
+ if handler_exists or len (exiting_op_names ) > 1 :
949
+ raise ValueError (
950
+ 'Only one global exit_handler is allowed and all ops need to be included.'
951
+ )
952
+ handler_exists = True
953
+
954
+ if group .ops :
955
+ exiting_op_names .extend ([x .name for x in group .ops ])
956
+
957
+ for g in group .groups :
958
+ _validate_exit_handler_helper (g , exiting_op_names , handler_exists )
959
+
960
+ return _validate_exit_handler_helper (pipeline .groups [0 ], [], False )
961
+
914
962
# TODO: Sanitizing beforehand, so that we don't need to sanitize here.
915
963
def _sanitize_and_inject_artifact (self , pipeline : dsl .Pipeline ) -> None :
916
964
"""Sanitize operator/param names and inject pipeline artifact location. """
@@ -1006,6 +1054,7 @@ def _create_pipeline_v2(
1006
1054
with dsl .Pipeline (pipeline_name ) as dsl_pipeline :
1007
1055
pipeline_func (* args_list )
1008
1056
1057
+ self ._validate_exit_handler (dsl_pipeline )
1009
1058
self ._sanitize_and_inject_artifact (dsl_pipeline )
1010
1059
1011
1060
# Fill in the default values.
0 commit comments