diff --git a/examples/build_api/sequence.py b/examples/build_api/sequence.py index 29eff72b..b25c4f0b 100644 --- a/examples/build_api/sequence.py +++ b/examples/build_api/sequence.py @@ -44,7 +44,6 @@ def forward(self, x): export.ExportPytorchModel(), export.OptimizeOnnxModel(), # export.ConvertOnnxToFp16(), #<-- This is the step we want to skip - export.SuccessStage(), ], enable_model_validation=True, ) diff --git a/examples/cli/plugins/example_combined/turnkeyml_plugin_example_combined/sequence.py b/examples/cli/plugins/example_combined/turnkeyml_plugin_example_combined/sequence.py index 55ba2083..70f2d1ca 100644 --- a/examples/cli/plugins/example_combined/turnkeyml_plugin_example_combined/sequence.py +++ b/examples/cli/plugins/example_combined/turnkeyml_plugin_example_combined/sequence.py @@ -27,7 +27,6 @@ def fire(self, state: build.State): stages=[ export.ExportPlaceholder(), CombinedExampleStage(), - export.SuccessStage(), ], enable_model_validation=True, ) diff --git a/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py b/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py index 350f2a76..d3ef72f2 100644 --- a/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py +++ b/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py @@ -43,7 +43,6 @@ def fire(self, state: build.State): stages=[ export.ExportPlaceholder(), ExampleStage(), - export.SuccessStage(), ], enable_model_validation=True, ) diff --git a/src/turnkeyml/build/export.py b/src/turnkeyml/build/export.py index 6d32f052..eee8a306 100644 --- a/src/turnkeyml/build/export.py +++ b/src/turnkeyml/build/export.py @@ -620,23 +620,3 @@ def fire(self, state: build.State): raise exp.StageError(msg) return state - - -class SuccessStage(stage.Stage): - """ - Stage that sets state.build_status = build.Status.SUCCESSFUL_BUILD, - indicating that the build sequence has completed all of the requested build stages. - """ - - def __init__(self): - super().__init__( - unique_name="set_success", - monitor_message="Finishing up", - ) - - def fire(self, state: build.State): - state.build_status = build.Status.SUCCESSFUL_BUILD - - state.results = copy.deepcopy(state.intermediate_results) - - return state diff --git a/src/turnkeyml/build/sequences.py b/src/turnkeyml/build/sequences.py index 7e90ead3..abeb159c 100644 --- a/src/turnkeyml/build/sequences.py +++ b/src/turnkeyml/build/sequences.py @@ -9,7 +9,6 @@ export.ExportPlaceholder(), export.OptimizeOnnxModel(), export.ConvertOnnxToFp16(), - export.SuccessStage(), ], enable_model_validation=True, ) @@ -20,7 +19,6 @@ [ export.ExportPlaceholder(), export.OptimizeOnnxModel(), - export.SuccessStage(), ], enable_model_validation=True, ) @@ -30,7 +28,6 @@ "Base Sequence", [ export.ExportPlaceholder(), - export.SuccessStage(), ], enable_model_validation=True, ) diff --git a/src/turnkeyml/build/stage.py b/src/turnkeyml/build/stage.py index 0267995e..daad45c0 100644 --- a/src/turnkeyml/build/stage.py +++ b/src/turnkeyml/build/stage.py @@ -2,6 +2,7 @@ import sys import time import os +import copy from typing import List, Tuple from multiprocessing import Process import psutil @@ -108,8 +109,7 @@ def fire_helper(self, state: build.State) -> Tuple[build.State, int]: # Set the build status to BUILD_RUNNING to indicate that a Stage # started running. This allows us to test whether the Stage exited - # unexpectedly, before it was able to set FAILED_BUILD, SUCCESSFUL_BUILD, - # or PARTIAL_BUILD + # unexpectedly, before it was able to set FAILED_BUILD state.build_status = build.Status.BUILD_RUNNING self.logfile_path = os.path.join( @@ -137,11 +137,16 @@ def fire_helper(self, state: build.State) -> Tuple[build.State, int]: else: self.status_line(successful=True, verbosity=state.monitor) - # Set the build status PARTIAL_BUILD, indicating that the stage - # ran successfully, unless the stage set SUCCESSFUL_BUILD, in which - # case leave the build status alone. - if state.build_status != build.Status.SUCCESSFUL_BUILD: - state.build_status = build.Status.PARTIAL_BUILD + # Stages should not set build.Status.SUCCESSFUL_BUILD, as that is + # reserved for Sequence.launch() + if state.build_status == build.Status.SUCCESSFUL_BUILD: + raise exp.StageError( + "TurnkeyML Stages are not allowed to set " + "`state.build_status == build.Status.SUCCESSFUL_BUILD`, " + "however that has happened. If you are a plugin developer, " + "do not do this. If you are a user, please file an issue at " + "https://github.com/onnx/turnkeyml/issues." + ) finally: if state.monitor: @@ -314,6 +319,14 @@ def launch(self, state: build.State) -> build.State: else: state.current_build_stage = None + state.build_status = build.Status.SUCCESSFUL_BUILD + + # We use a deepcopy here because the Stage framework supports + # intermediate_results of any type, including model objects in memory. + # The deepcopy ensures that we are providing a result that users + # are free to take any action with. + state.results = copy.deepcopy(state.intermediate_results) + return state def status_line(self, successful, verbosity): diff --git a/test/build_model.py b/test/build_model.py index 3e1b9ae6..4503294e 100644 --- a/test/build_model.py +++ b/test/build_model.py @@ -232,38 +232,6 @@ def scriptmodule_functional_check(): return state.build_status == build.Status.SUCCESSFUL_BUILD -def full_compile_individual_stages(): - build_name = "full_compile_individual_stages" - build_model( - pytorch_model, - inputs, - build_name=build_name, - rebuild="always", - monitor=False, - sequence=stage.Sequence( - "ExportPytorchModel_seq", "", [export.ExportPytorchModel()] - ), - cache_dir=cache_location, - ) - build_model( - build_name=build_name, - sequence=stage.Sequence("OptimizeModel_seq", "", [export.OptimizeOnnxModel()]), - cache_dir=cache_location, - ) - build_model( - build_name=build_name, - sequence=stage.Sequence("Fp16Conversion_seq", "", [export.ConvertOnnxToFp16()]), - cache_dir=cache_location, - ) - state = build_model( - build_name=build_name, - sequence=stage.Sequence("SuccessStage_seq", "", [export.SuccessStage()]), - cache_dir=cache_location, - ) - - return state.build_status == build.Status.SUCCESSFUL_BUILD - - def custom_stage(): build_name = "custom_stage" @@ -299,7 +267,6 @@ def fire(self, state): export.ExportPytorchModel(), export.OptimizeOnnxModel(), my_custom_stage, - export.SuccessStage(), ], ) @@ -537,9 +504,6 @@ def test_006_full_compilation_hummingbird_rf(self): def test_007_full_compilation_hummingbird_xgb(self): assert full_compilation_hummingbird_xgb() - def test_008_full_compile_individual_stages(self): - assert full_compile_individual_stages() - def test_009_custom_stage(self): assert custom_stage()