Skip to content

Commit

Permalink
Remove the SetSuccess stage (and the need for it) (#59)
Browse files Browse the repository at this point in the history
Signed-off-by: Jeremy Fowers <jeremy.fowers@amd.com>
  • Loading branch information
jeremyfowers committed Dec 6, 2023
1 parent 2e84a3d commit 4ed7516
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 69 deletions.
1 change: 0 additions & 1 deletion examples/build_api/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ def forward(self, x):
export.ExportPytorchModel(),
export.OptimizeOnnxModel(),
# export.ConvertOnnxToFp16(), #<-- This is the step we want to skip
export.SuccessStage(),
],
enable_model_validation=True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def fire(self, state: build.State):
stages=[
export.ExportPlaceholder(),
CombinedExampleStage(),
export.SuccessStage(),
],
enable_model_validation=True,
)
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def fire(self, state: build.State):
stages=[
export.ExportPlaceholder(),
ExampleStage(),
export.SuccessStage(),
],
enable_model_validation=True,
)
20 changes: 0 additions & 20 deletions src/turnkeyml/build/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,23 +620,3 @@ def fire(self, state: build.State):
raise exp.StageError(msg)

return state


class SuccessStage(stage.Stage):
"""
Stage that sets state.build_status = build.Status.SUCCESSFUL_BUILD,
indicating that the build sequence has completed all of the requested build stages.
"""

def __init__(self):
super().__init__(
unique_name="set_success",
monitor_message="Finishing up",
)

def fire(self, state: build.State):
state.build_status = build.Status.SUCCESSFUL_BUILD

state.results = copy.deepcopy(state.intermediate_results)

return state
3 changes: 0 additions & 3 deletions src/turnkeyml/build/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
export.ExportPlaceholder(),
export.OptimizeOnnxModel(),
export.ConvertOnnxToFp16(),
export.SuccessStage(),
],
enable_model_validation=True,
)
Expand All @@ -20,7 +19,6 @@
[
export.ExportPlaceholder(),
export.OptimizeOnnxModel(),
export.SuccessStage(),
],
enable_model_validation=True,
)
Expand All @@ -30,7 +28,6 @@
"Base Sequence",
[
export.ExportPlaceholder(),
export.SuccessStage(),
],
enable_model_validation=True,
)
Expand Down
27 changes: 20 additions & 7 deletions src/turnkeyml/build/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import time
import os
import copy
from typing import List, Tuple
from multiprocessing import Process
import psutil
Expand Down Expand Up @@ -108,8 +109,7 @@ def fire_helper(self, state: build.State) -> Tuple[build.State, int]:

# Set the build status to BUILD_RUNNING to indicate that a Stage
# started running. This allows us to test whether the Stage exited
# unexpectedly, before it was able to set FAILED_BUILD, SUCCESSFUL_BUILD,
# or PARTIAL_BUILD
# unexpectedly, before it was able to set FAILED_BUILD
state.build_status = build.Status.BUILD_RUNNING

self.logfile_path = os.path.join(
Expand Down Expand Up @@ -137,11 +137,16 @@ def fire_helper(self, state: build.State) -> Tuple[build.State, int]:
else:
self.status_line(successful=True, verbosity=state.monitor)

# Set the build status PARTIAL_BUILD, indicating that the stage
# ran successfully, unless the stage set SUCCESSFUL_BUILD, in which
# case leave the build status alone.
if state.build_status != build.Status.SUCCESSFUL_BUILD:
state.build_status = build.Status.PARTIAL_BUILD
# Stages should not set build.Status.SUCCESSFUL_BUILD, as that is
# reserved for Sequence.launch()
if state.build_status == build.Status.SUCCESSFUL_BUILD:
raise exp.StageError(
"TurnkeyML Stages are not allowed to set "
"`state.build_status == build.Status.SUCCESSFUL_BUILD`, "
"however that has happened. If you are a plugin developer, "
"do not do this. If you are a user, please file an issue at "
"https://github.com/onnx/turnkeyml/issues."
)

finally:
if state.monitor:
Expand Down Expand Up @@ -314,6 +319,14 @@ def launch(self, state: build.State) -> build.State:

else:
state.current_build_stage = None
state.build_status = build.Status.SUCCESSFUL_BUILD

# We use a deepcopy here because the Stage framework supports
# intermediate_results of any type, including model objects in memory.
# The deepcopy ensures that we are providing a result that users
# are free to take any action with.
state.results = copy.deepcopy(state.intermediate_results)

return state

def status_line(self, successful, verbosity):
Expand Down
36 changes: 0 additions & 36 deletions test/build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,38 +232,6 @@ def scriptmodule_functional_check():
return state.build_status == build.Status.SUCCESSFUL_BUILD


def full_compile_individual_stages():
build_name = "full_compile_individual_stages"
build_model(
pytorch_model,
inputs,
build_name=build_name,
rebuild="always",
monitor=False,
sequence=stage.Sequence(
"ExportPytorchModel_seq", "", [export.ExportPytorchModel()]
),
cache_dir=cache_location,
)
build_model(
build_name=build_name,
sequence=stage.Sequence("OptimizeModel_seq", "", [export.OptimizeOnnxModel()]),
cache_dir=cache_location,
)
build_model(
build_name=build_name,
sequence=stage.Sequence("Fp16Conversion_seq", "", [export.ConvertOnnxToFp16()]),
cache_dir=cache_location,
)
state = build_model(
build_name=build_name,
sequence=stage.Sequence("SuccessStage_seq", "", [export.SuccessStage()]),
cache_dir=cache_location,
)

return state.build_status == build.Status.SUCCESSFUL_BUILD


def custom_stage():
build_name = "custom_stage"

Expand Down Expand Up @@ -299,7 +267,6 @@ def fire(self, state):
export.ExportPytorchModel(),
export.OptimizeOnnxModel(),
my_custom_stage,
export.SuccessStage(),
],
)

Expand Down Expand Up @@ -537,9 +504,6 @@ def test_006_full_compilation_hummingbird_rf(self):
def test_007_full_compilation_hummingbird_xgb(self):
assert full_compilation_hummingbird_xgb()

def test_008_full_compile_individual_stages(self):
assert full_compile_individual_stages()

def test_009_custom_stage(self):
assert custom_stage()

Expand Down

0 comments on commit 4ed7516

Please sign in to comment.