Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the deprecated quantization tool #53

Merged
merged 2 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/coverage.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ Name Stmts Miss Branch BrPart Cover Mi
--------------------------------------------------------------------------------------------------------
\turnkeyml\build\__init__.py 0 0 0 0 100%
\turnkeyml\build\onnx_helpers.py 70 34 28 2 45% 15-21, 28-87, 92, 95-100
\turnkeyml\build\quantization_helpers.py 29 20 18 0 19% 13-30, 35, 50-78
\turnkeyml\build\sequences.py 15 1 8 2 87% 62->61, 65
\turnkeyml\build\tensor_helpers.py 47 26 34 4 41% 17-44, 57, 61, 63-74, 78
\turnkeyml\build_api.py 31 9 8 3 64% 68-71, 120-125, 140-147
Expand Down
63 changes: 1 addition & 62 deletions src/turnkeyml/build/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import turnkeyml.common.build as build
import turnkeyml.build.tensor_helpers as tensor_helpers
import turnkeyml.build.onnx_helpers as onnx_helpers
import turnkeyml.build.quantization_helpers as quant_helpers
import turnkeyml.common.filesystem as fs


Expand Down Expand Up @@ -77,13 +76,6 @@ def converted_onnx_file(state: build.State):
)


def quantized_onnx_file(state: build.State):
return os.path.join(
onnx_dir(state),
f"{state.config.build_name}-op{state.config.onnx_opset}-opt-quantized_int8.onnx",
)


class ExportPlaceholder(stage.Stage):
"""
Placeholder Stage that should be replaced by a framework-specific export stage,
Expand Down Expand Up @@ -571,9 +563,8 @@ def fire(self, state: build.State):
inputs_file = state.original_inputs_file
if os.path.isfile(inputs_file):
inputs = np.load(inputs_file, allow_pickle=True)
to_downcast = False if state.quantization_samples else True
inputs_converted = tensor_helpers.save_inputs(
inputs, inputs_file, downcast=to_downcast
inputs, inputs_file, downcast=True
jeremyfowers marked this conversation as resolved.
Show resolved Hide resolved
)
else:
raise exp.StageError(
Expand Down Expand Up @@ -621,58 +612,6 @@ def fire(self, state: build.State):
return state


class QuantizeONNXModel(stage.Stage):
"""
Stage that takes an ONNX model and a dataset of quantization samples as inputs,
and performs static post-training quantization to the model to int8 precision.

Expected inputs:
- state.model is a path to the ONNX model
- state.quantization_dataset is a dataset that is used for static quantization

Outputs:
- A *_quantized.onnx file => the quantized onnx model.
"""

def __init__(self):
super().__init__(
unique_name="quantize_onnx",
monitor_message="Quantizing ONNX model",
)

def fire(self, state: build.State):
input_path = state.intermediate_results[0]
output_path = quantized_onnx_file(state)

quant_helpers.quantize(
input_file=input_path,
data=state.quantization_samples,
output_file=output_path,
)

# Check that the converted model is still valid
success_msg = "\tSuccess quantizing ONNX model to int8"
fail_msg = "\tFailed quantizing ONNX model to int8"

if check_model(output_path, success_msg, fail_msg):
state.intermediate_results = [output_path]

stats = fs.Stats(state.cache_dir, state.config.build_name, state.stats_id)
stats.add_build_stat(
fs.Keys.ONNX_FILE,
output_path,
)
else:
msg = f"""
Attempted to use {state.quantization_dataset} to statically quantize
model to int8 datatype, however this operation was not successful.
More information may be available in the log file at **{self.logfile_path}**
"""
raise exp.StageError(msg)

return state


class SuccessStage(stage.Stage):
"""
Stage that sets state.build_status = build.Status.SUCCESSFUL_BUILD,
Expand Down
65 changes: 5 additions & 60 deletions src/turnkeyml/build/ignition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Optional, List, Tuple, Union, Dict, Any, Type, Callable
from collections.abc import Collection
import sys
import os
import copy
Expand Down Expand Up @@ -258,7 +257,6 @@ def load_or_make_state(
monitor: bool,
model: build.UnionValidModelInstanceTypes = None,
inputs: Optional[Dict[str, Any]] = None,
quantization_samples: Optional[Collection] = None,
state_type: Type = build.State,
cache_validation_func: Callable = validate_cached_model,
extra_state_args: Optional[Dict] = None,
Expand All @@ -280,7 +278,6 @@ def load_or_make_state(
"cache_dir": cache_dir,
"config": config,
"model_type": model_type,
"quantization_samples": quantization_samples,
}

# Ensure that `rebuild` has a valid value
Expand All @@ -306,50 +303,6 @@ def load_or_make_state(
state_type=state_type,
)

# if the previous build is using quantization while the current is not
# or vice versa
if state.quantization_samples and quantization_samples is None:
if rebuild == "never":
msg = (
f"Model {config.build_name} was built in a previous call to "
"build_model() with post-training quantization sample enabled."
"However, post-training quantization is not enabled in the "
"current build. Rebuild is necessary but currently the rebuild"
"policy is set to 'never'. "
)
raise exp.CacheError(msg)

msg = (
f"Model {config.build_name} was built in a previous call to "
"build_model() with post-training quantization sample enabled."
"However, post-training quantization is not enabled in the "
"current build. Starting a fresh build."
)

printing.log_info(msg)
return _begin_fresh_build(state_args, state_type)

if not state.quantization_samples and quantization_samples is not None:
if rebuild == "never":
msg = (
f"Model {config.build_name} was built in a previous call to "
"build_model() with post-training quantization sample disabled."
"However, post-training quantization is enabled in the "
"current build. Rebuild is necessary but currently the rebuild"
"policy is set to 'never'. "
)
raise exp.CacheError(msg)

msg = (
f"Model {config.build_name} was built in a previous call to "
"build_model() with post-training quantization sample disabled."
"However, post-training quantization is enabled in the "
"current build. Starting a fresh build."
)

printing.log_info(msg)
return _begin_fresh_build(state_args, state_type)

except exp.StateError as e:
problem = (
"- build_model() failed to load "
Expand Down Expand Up @@ -500,7 +453,6 @@ def model_intake(
user_model,
user_inputs,
user_sequence: Optional[stage.Sequence],
user_quantization_samples: Optional[Collection] = None,
) -> Tuple[Any, Any, stage.Sequence, build.ModelType, str]:
# Model intake structure options:
# user_model
Expand Down Expand Up @@ -550,18 +502,11 @@ def model_intake(

sequence = copy.deepcopy(user_sequence)
if sequence is None:
if user_quantization_samples:
if model_type != build.ModelType.PYTORCH:
raise exp.IntakeError(
"Currently, post training quantization only supports Pytorch models."
)
sequence = sequences.pytorch_with_quantization
else:
sequence = stage.Sequence(
"top_level_sequence",
"Top Level Sequence",
[sequences.onnx_fp32],
)
sequence = stage.Sequence(
"top_level_sequence",
"Top Level Sequence",
[sequences.onnx_fp32],
)

# If there is an ExportPlaceholder Stage in the sequence, replace it with
# a framework-specific export Stage.
Expand Down
78 changes: 0 additions & 78 deletions src/turnkeyml/build/quantization_helpers.py

This file was deleted.

12 changes: 0 additions & 12 deletions src/turnkeyml/build/sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,6 @@
enable_model_validation=True,
)

pytorch_with_quantization = stage.Sequence(
"pytorch_export_sequence_with_quantization",
"Exporting PyTorch Model and Quantizing Exported ONNX",
[
export.ExportPytorchModel(),
export.OptimizeOnnxModel(),
export.QuantizeONNXModel(),
export.SuccessStage(),
],
enable_model_validation=True,
)

# Plugin interface for sequences
discovered_plugins = plugins.discover()

Expand Down
11 changes: 0 additions & 11 deletions src/turnkeyml/build_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
from typing import Optional, List, Dict, Any
from collections.abc import Collection
import turnkeyml.build.ignition as ignition
import turnkeyml.build.stage as stage
import turnkeyml.common.printing as printing
Expand All @@ -17,7 +16,6 @@ def build_model(
monitor: Optional[bool] = None,
rebuild: Optional[str] = None,
sequence: Optional[List[stage.Stage]] = None,
quantization_samples: Collection = None,
onnx_opset: Optional[int] = None,
device: Optional[str] = None,
) -> build.State:
Expand Down Expand Up @@ -48,11 +46,6 @@ def build_model(
- None: Falls back to default
sequence: Override the default sequence of build stages. Power
users only.
quantization_samples: If set, performs post-training quantization
on the ONNX model using the provided samplesIf the previous build used samples
that are different to the samples used in current build, the "rebuild"
argument needs to be manually set to "always" in the current build
in order to create a new ONNX file.
onnx_opset: ONNX opset to use during ONNX export.
device: Specific device target to take into account during the build sequence.
Use the format "device_family", "device_family::part", or
Expand Down Expand Up @@ -96,7 +89,6 @@ def build_model(
model,
inputs,
sequence,
user_quantization_samples=quantization_samples,
)

# Get the state of the model from the cache if a valid build is available
Expand All @@ -109,7 +101,6 @@ def build_model(
monitor=monitor_setting,
model=model_locked,
inputs=inputs_locked,
quantization_samples=quantization_samples,
)

# Return a cached build if possible, otherwise prepare the model State for
Expand All @@ -124,8 +115,6 @@ def build_model(

return state

state.quantization_samples = quantization_samples

sequence_locked.show_monitor(config, state.monitor)
state = sequence_locked.launch(state)

Expand Down
Loading
Loading