diff --git a/.github/workflows/publish-to-test-pypi.yml b/.github/workflows/publish-to-test-pypi.yml index 0808cdb6..09c9aac7 100644 --- a/.github/workflows/publish-to-test-pypi.yml +++ b/.github/workflows/publish-to-test-pypi.yml @@ -5,6 +5,7 @@ on: branches: ["main", "canary"] tags: - v* + - RC* pull_request: branches: ["main", "canary"] @@ -33,7 +34,13 @@ jobs: models=$(turnkey models location --quiet) turnkey $models/selftest/linear.py - name: Publish distribution package to PyPI - if: startsWith(github.ref, 'refs/tags') + if: startsWith(github.ref, 'refs/tags/v') uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} + - name: Publish distribution package to Test PyPI + if: startsWith(github.ref, 'refs/tags/RC') + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.TEST_PYPI_API_TOKEN }} + repository_url: https://test.pypi.org/legacy/ diff --git a/.github/workflows/test_turnkey.yml b/.github/workflows/test_turnkey.yml index 915bc47a..1ed8168e 100644 --- a/.github/workflows/test_turnkey.yml +++ b/.github/workflows/test_turnkey.yml @@ -53,8 +53,6 @@ jobs: # turnkey examples # Note: we clear the default cache location prior to each example run rm -rf ~/.cache/turnkey - python examples/model_api/hello_world.py - rm -rf ~/.cache/turnkey python examples/files_api/onnx_opset.py --onnx-opset 15 rm -rf ~/.cache/turnkey turnkey examples/cli/scripts/hello_world.py @@ -71,7 +69,6 @@ jobs: cd test/ python cli.py python analysis.py - python model_api.py - name: Test example plugins shell: bash -el {0} run: | diff --git a/docs/code.md b/docs/code.md index e9fdc7a0..28663677 100644 --- a/docs/code.md +++ b/docs/code.md @@ -11,8 +11,8 @@ The TurnkeyML source code has a few major top-level directories: - `models`: the corpora of models that makes up the TurnkeyML models (see [the models readme](https://github.com/onnx/turnkeyml/blob/main/models/readme.md)). - Each subdirectory under `models` represents a corpus of models pulled from somewhere on the internet. For example, `models/torch_hub` is a corpus of models from [Torch Hub](https://github.com/pytorch/hub). - `src/turnkey`: source code for the TurnkeyML tools (see [Benchmarking Tools](#benchmarking-tools) for a description of how the code is used). - - `src/turnkeyml/analyze`: functions for profiling a model script, discovering model instances, and invoking `benchmark_model()` on those instances. - - `src/turnkeyml/run`: implements the runtime and device plugin APIs and the built-in runtimes and devices. + - `src/turnkeyml/analyze`: functions for profiling a model script, discovering model instances, and invoking `build_model()` and/or `BaseRT.benchmark()` on those instances. + - `src/turnkeyml/run`: implements `BaseRT`, an abstract base class that defines TurnkeyML's vendor-agnostic benchmarking functionality. This module also includes the runtime and device plugin APIs and the built-in runtimes and devices. - `src/turnkeyml/cli`: implements the `turnkey` CLI and reporting tool. - `src/turnkeyml/common`: functions common to the other modules. - `src/turnkeyml/version.py`: defines the package version number. @@ -29,10 +29,9 @@ TurnkeyML provides two main tools, the `turnkey` CLI and benchmarking APIs. Inst 1. The default command for `turnkey` CLI runs the `benchmark_files()` API, which is implemented in [files_api.py](https://github.com/onnx/turnkeyml/blob/main/src/turnkeyml/files_api.py). - Other CLI commands are also implemented in `cli/`, for example the `report` command is implemented in `cli/report.py`. 1. The `benchmark_files()` API takes in a set of scripts, each of which should invoke at least one model instance, to evaluate and passes each into the `evaluate_script()` function for analysis, which is implemented in [analyze/script.py](https://github.com/onnx/turnkeyml/blob/main/src/turnkeyml/analyze/script.py). -1. `evaluate_script()` uses a profiler to discover the model instances in the script, and passes each into the `benchmark_model()` API, which is defined in [model_api.py](https://github.com/onnx/turnkeyml/blob/main/src/turnkeyml/model_api.py). -1. The `benchmark_model()` API prepares the model for benchmarking (e.g., exporting and optimizing an ONNX file), which creates an instance of a `*Model` class, where `*` can be CPU, GPU, etc. The `*Model` classes are defined in [run/](https://github.com/onnx/turnkeyml/blob/main/src/turnkeyml/run/). -1. The `*Model` classes provide a `.benchmark()` method that benchmarks the model on the device and returns an instance of the `MeasuredPerformance` class, which includes the performance statistics acquired during benchmarking. -1. `benchmark_model()` and the `*Model` classes are built using [`build_model()`](#model-build-tool) +1. `evaluate_script()` uses a profiler to discover the model instances in the script, and passes each into the `build_model()` API, which is defined in [build_api.py](https://github.com/onnx/turnkeyml/blob/main/src/turnkeyml/build_api.py). +1. The `build_model()` API prepares the model for benchmarking (e.g., exporting and optimizing an ONNX file). +1. `evaluate_script()` passes the build into `BaseRT.benchmark()` to benchmarks the model on the device and returns an instance of the `MeasuredPerformance` class, which includes the performance statistics acquired during benchmarking. # Model Build Tool diff --git a/docs/readme.md b/docs/readme.md index 55ec96c0..5b26c33c 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -3,7 +3,7 @@ This directory contains documentation for the TurnkeyML project: - [code.md](https://github.com/onnx/turnkeyml/blob/main/docs/code.md): Code organization for the benchmark and tools. - [install.md](https://github.com/onnx/turnkeyml/blob/main/docs/install.md): Installation instructions for the tools. -- [tools_user_guide.md](https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md): User guide for the tools: `turnkey` CLI, `benchmark_files()`, and `benchmark_model()`. +- [tools_user_guide.md](https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md): User guide for the tools: the `turnkey` CLI and the `benchmark_files()` and `build_model()` APIs. - [versioning.md](https://github.com/onnx/turnkeyml/blob/main/docs/versioning.md): Defines the semantic versioning rules for the `turnkey` package. There is more useful documentation available in: diff --git a/docs/tools_user_guide.md b/docs/tools_user_guide.md index d9819979..53d24635 100644 --- a/docs/tools_user_guide.md +++ b/docs/tools_user_guide.md @@ -51,8 +51,8 @@ Where `your_script.py` is a Python script that instantiates and executes a PyTor The `turnkey` CLI performs the following steps: 1. [Analysis](#analysis): profile the Python script to identify the PyTorch models within -2. [Build](#build): call the `benchmark_files()` [API](#the-turnkey-api) to prepare each model for benchmarking -3. [Benchmark](#benchmark): call the `benchmark_model()` [API](#the-turnkey-api) on each model to gather performance statistics +2. [Build](#build): call the `build_models()` [API](#the-turnkey-api) to prepare each model for benchmarking +3. [Benchmark](#benchmark): call the `BaseRT.benchmark()` method on each model to gather performance statistics _Note_: The benchmarking methodology is defined [here](#benchmark). If you are looking for more detailed instructions on how to install turnkey, you can find that [here](https://github.com/onnx/turnkeyml/blob/main/docs/install.md). @@ -64,31 +64,11 @@ _Note_: The benchmarking methodology is defined [here](#benchmark). If you are l Most of the functionality provided by the `turnkey` CLI is also available in the the API: - `turnkey.benchmark_files()` provides the same benchmarking functionality as the `turnkey` CLI: it takes a list of files and target device, and returns performance results. -- `turnkey.benchmark_model()` provides a subset of this functionality: it takes a model and its inputs, and returns performance results. - - The main difference is that `benchmark_model()` does not include the [Analysis](#analysis) feature, and `benchmark_files()` does. - `turnkey.build_model(model, inputs)` is used to programmatically [build](#build) a model instance through a sequence of model-to-model transformations (e.g., starting with an fp32 PyTorch model and ending with an fp16 ONNX model). -Generally speaking, the `turnkey` CLI is a command line interface for the `benchmark_files()` API, which internally calls `benchmark_model()`, which in turn calls `build_model()`. You can read more about this code organization [here](https://github.com/onnx/turnkeyml/blob/main/docs/code.md). +Generally speaking, the `turnkey` CLI is a command line interface for the `benchmark_files()` API which in turn calls `build_model()` and then performs benchmarking using `BaseRT.benchmark()`. You can read more about this code organization [here](https://github.com/onnx/turnkeyml/blob/main/docs/code.md). -For an example of `benchmark_model()`, the following script: - -```python -from turnkeyml import benchmark_model - -model = YourModel() # Instantiate a torch.nn.module -results = model(**inputs) -perf = benchmark_model(model, inputs) -``` - -Will print an output like this: - -``` -> Performance of YourModel on device Intel® Xeon® Platinum 8380 is: -> latency: 0.033 ms -> throughput: 21784.8 ips -``` - -`benchmark_model()` returns a `MeasuredPerformance` object that includes members: +`BaseRT.benchmark()` returns a `MeasuredPerformance` object that includes members: - `latency_units`: unit of time used for measuring latency, which is set to `milliseconds (ms)`. - `mean_latency`: average benchmarking latency, measured in `latency_units`. - `throughput_units`: unit used for measuring throughput, which is set to `inferences per second (IPS)`. @@ -135,7 +115,7 @@ A **runtime** is a piece of software that executes a model on a device. **Analysis** is the process by which `benchmark_files()` inspects a Python script or ONNX file and identifies the models within. -`benchmark_files()` performs analysis by running and profiling your file(s). When a model object (see [Model](#model) is encountered, it is inspected to gather statistics (such as the number of parameters in the model) and/or pass it to the `benchmark_model()` API for benchmarking. +`benchmark_files()` performs analysis by running and profiling your file(s). When a model object (see [Model](#model) is encountered, it is inspected to gather statistics (such as the number of parameters in the model) and/or passed to the build and benchmark APIs. > _Note_: the `turnkey` CLI and `benchmark_files()` API both run your entire python script(s) whenever python script(s) are passed as input files. Please ensure that these scripts are safe to run, especially if you got them from the internet. @@ -205,12 +185,14 @@ The *build cache* is a location on disk that holds all of the artifacts from you ## Benchmark -*Benchmark* is the process by which the `benchmark_model()` API collects performance statistics about a [model](#model). Specifically, `benchmark_model()` takes a [build](#build) of a model and executes it on a target device using target runtime software (see [Devices and Runtimes](#devices-and-runtimes)). +*Benchmark* is the process by which `BaseRT.benchmark()` collects performance statistics about a [model](#model). `BaseRT` is an abstract base class that defines the common benchmarking infrastructure that TurnkeyML provides across devices and runtimes. + +Specifically, `BaseRT.benchmark()` takes a [build](#build) of a model and executes it on a target device using target runtime software (see [Devices and Runtimes](#devices-and-runtimes)). -By default, `benchmark_model()` will run the model 100 times to collect the following statistics: +By default, `BaseRT.benchmark()` will run the model 100 times to collect the following statistics: 1. Mean Latency, in milliseconds (ms): the average time it takes the runtime/device combination to execute the model/inputs combination once. This includes the time spent invoking the device and transferring the model's inputs and outputs between host memory and the device (when applicable). 1. Throughput, in inferences per second (IPS): the number of times the model/inputs combination can be executed on the runtime/device combination per second. - > - _Note_: `benchmark_model()` is not aware of whether `inputs` is a single input or a batch of inputs. If your `inputs` is actually a batch of inputs, you should multiply `benchmark_model()`'s reported IPS by the batch size. + > - _Note_: `BaseRT.benchmark()` is not aware of whether `inputs` is a single input or a batch of inputs. If your `inputs` is actually a batch of inputs, you should multiply `BaseRT.benchmark()`'s reported IPS by the batch size. # Devices and Runtimes @@ -226,7 +208,7 @@ If you are using a remote machine, it must: - include the target device - have `miniconda`, `python>=3.8`, and `docker>=20.10` installed -When you call `turnkey` CLI or `benchmark_model()`, the following actions are performed on your behalf: +When you call `turnkey` CLI or `benchmark_files()`, the following actions are performed on your behalf: 1. Perform a `build`, which exports all models from the script to ONNX and prepares for benchmarking. 1. Set up the benchmarking environment by loading a container and/or setting up a conda environment. 1. Run the benchmarks. @@ -253,7 +235,6 @@ Valid values of `TYPE` include: Also available as API arguments: - `benchmark_files(device=...)` -- `benchmark_model(device=...)`. > For a detailed example, see the [CLI Nvidia tutorial](https://github.com/onnx/turnkeyml/blob/main/examples/cli/readme.md#nvidia-benchmarking). @@ -274,9 +255,8 @@ Each device type has its own default runtime, as indicated below. This feature is also be available as an API argument: - `benchmark_files(runtime=[...])` -- `benchmark_model(runtime=...)` -> _Note_: Inputs to `torch-eager` and `torch-compiled` are not downcasted to FP16 by default. Downcast inputs before benchmarking for a fair comparison between runtimes. +> _Note_: Inputs to `torch-eager` and `torch-compiled` are not downcasted to FP16 by default. You must perform your own downcast or quantization of inputs if needed for apples-to-apples comparisons with other runtimes. # Additional Commands and Options @@ -381,7 +361,6 @@ Process isolation mode applies a timeout to each subprocess. The default timeout Also available as API arguments: - `benchmark_files(cache_dir=...)` -- `benchmark_model(cache_dir=...)` - `build_model(cache_dir=...)` > See the [Cache Directory tutorial](https://github.com/onnx/turnkeyml/blob/main/examples/cli/cache.md#cache-directory) for a detailed example. @@ -392,7 +371,6 @@ Also available as API arguments: Also available as API arguments: - `benchmark_files(lean_cache=True/False, ...)` (default False) -- `benchmark_model(lean_cache=True/False, ...)` (default False) > _Note_: useful for benchmarking many models, since the `build` artifacts from the models can take up a significant amount of hard drive space. @@ -409,7 +387,6 @@ Takes one of the following values: Also available as API arguments: - `benchmark_files(rebuild=...)` -- `benchmark_model(rebuild=...)` - `build_model(rebuild=...)` ### Sequence @@ -421,7 +398,6 @@ Usage: Also available as API arguments: - `benchmark_files(sequence=...)` -- `benchmark_model(sequence=...)` - `build_model(sequence=...)` ### Set Script Arguments @@ -460,7 +436,6 @@ Usage: Also available as API arguments: - `benchmark_files(onnx_opset=...)` -- `benchmark_model(onnx_opset=...)` - `build_model(onnx_opset=...)` > _Note_: ONNX opset can also be set by an environment variable. The --onnx-opset argument takes precedence over the environment variable. See [TURNKEY_ONNX_OPSET](#set-the-onnx-opset). @@ -474,11 +449,10 @@ Usage: Also available as API arguments: - `benchmark_files(iterations=...)` -- `benchmark_model(iterations=...)` ### Analyze Only -Instruct `turnkey` or `benchmark_model()` to only run the [Analysis](#analysis) phase of the `benchmark` command. +Instruct `turnkey` or `benchmark_files()` to only run the [Analysis](#analysis) phase of the `benchmark` command. Usage: - `turnkey benchmark INPUT_FILES --analyze-only` @@ -493,7 +467,7 @@ Also available as an API argument: ### Build Only -Instruct `turnkey`, `benchmark_files()`, or `benchmark_model()` to only run the [Analysis](#analysis) and [Build](#build) phases of the `benchmark` command. +Instruct `turnkey` or `benchmark_files()` to only run the [Analysis](#analysis) and [Build](#build) phases of the `benchmark` command. Usage: - `turnkey benchmark INPUT_FILES --build-only` @@ -503,7 +477,6 @@ Usage: Also available as API arguments: - `benchmark_files(build_only=True/False)` (default False) -- `benchmark_model(build_only=True/False)` (default False) > See the [Build Only tutorial](https://github.com/onnx/turnkeyml/blob/main/examples/cli/build.md#build-only) for a detailed example. @@ -515,7 +488,6 @@ None of the built-in runtimes support such arguments, however plugin contributor Also available as API arguments: - `benchmark_files(rt_args=Dict)` (default None) -- `benchmark_model(rt_args=Dict)` (default None) ## Cache Commands @@ -635,7 +607,7 @@ export TURNKEY_DEBUG=True ### Set the ONNX Opset -By default, `turnkey`, `benchmark_files()`, and `benchmark_model()` will use the default ONNX opset defined in `turnkey.common.build.DEFAULT_ONNX_OPSET`. You can set a different default ONNX opset by setting the `TURNKEY_ONNX_OPSET` environment variable. +By default, `turnkey`, `benchmark_files()`, and `build_model()` will use the default ONNX opset defined in `turnkey.common.build.DEFAULT_ONNX_OPSET`. You can set a different default ONNX opset by setting the `TURNKEY_ONNX_OPSET` environment variable. For example: diff --git a/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py b/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py index 6c1a1229..350f2a76 100644 --- a/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py +++ b/examples/cli/plugins/example_seq/turnkeyml_plugin_example_seq/sequence.py @@ -1,7 +1,7 @@ """ This script is an example of a sequence.py file for Sequence Plugin. Such a sequence.py can be used to redefine the build phase of the turnkey CLI, benchmark_files(), -and benchmark_model() to have any custom behavior. +and build_model() to have any custom behavior. In this example sequence.py file we are setting the build sequence to simply export from pytorch to ONNX. This differs from the default build sequence, which diff --git a/examples/model_api/hello_world.py b/examples/model_api/hello_world.py deleted file mode 100644 index 6a5a4a6f..00000000 --- a/examples/model_api/hello_world.py +++ /dev/null @@ -1,62 +0,0 @@ -import argparse -import torch -from turnkeyml import benchmark_model - -torch.manual_seed(0) - - -# Define model class -class SmallModel(torch.nn.Module): - def __init__(self, input_size, output_size): - super(SmallModel, self).__init__() - self.fc = torch.nn.Linear(input_size, output_size) - - def forward(self, x): - output = self.fc(x) - return output - - -# Instantiate model and generate inputs -input_size = 1000 -output_size = 500 -pytorch_model = SmallModel(input_size, output_size) -inputs = {"x": torch.rand(input_size)} - - -def main(): - # Define the argument parser - parser = argparse.ArgumentParser( - description="Benchmark a PyTorch model on a specified device." - ) - - # Add the arguments - parser.add_argument( - "--device", - type=str, - choices=["x86", "nvidia"], - default="x86", - help="The device to benchmark on (x86 or nvidia)", - ) - - # Parse the arguments - args = parser.parse_args() - - # Instantiate model and generate inputs - torch.manual_seed(0) - input_size = 1000 - output_size = 500 - pytorch_model = SmallModel(input_size, output_size) - inputs = {"x": torch.rand(input_size)} - - # Benchmark the model on the specified device - print(f"Benchmarking on {args.device}...") - benchmark_model( - pytorch_model, - inputs, - build_name="hello_api_world", - device=args.device, - ) - - -if __name__ == "__main__": - main() diff --git a/examples/readme.md b/examples/readme.md index 8f0c2120..12ae689f 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -2,6 +2,5 @@ This directory contains examples to help you learn how to use the tools. The examples are split up into two sub-directories: 1. `examples/cli`: a tutorial series for the `turnkey` CLI. This is the recommended starting point. -1. `examples/model_api`: scripts that demonstrate how to use the `turnkey.benchmark_model()` API. 1. `examples/files_api`: scripts that demonstrate how to use the `turnkey.benchmark_files()` API. 1. `examples/build_api`: scripts that demonstrate how to use the `turnkey.build_model()` API. diff --git a/src/turnkeyml/__init__.py b/src/turnkeyml/__init__.py index 175c5891..0430c4ce 100644 --- a/src/turnkeyml/__init__.py +++ b/src/turnkeyml/__init__.py @@ -1,7 +1,6 @@ from turnkeyml.version import __version__ from .files_api import benchmark_files -from .model_api import benchmark_model from .cli.cli import main as turnkeycli from .build_api import build_model from .common.build import load_state diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py index cebf7999..3c288909 100644 --- a/src/turnkeyml/analyze/script.py +++ b/src/turnkeyml/analyze/script.py @@ -22,13 +22,9 @@ import turnkeyml.analyze.util as util import turnkeyml.common.tf_helpers as tf_helpers import turnkeyml.common.labels as labels -from turnkeyml.model_api import benchmark_model +from turnkeyml.build_api import build_model import turnkeyml.common.filesystem as fs -from turnkeyml.run.devices import ( - DEVICE_RUNTIME_MAP, - DEFAULT_RUNTIME, - SUPPORTED_RUNTIMES, -) +import turnkeyml.run.devices as plugins class Action(Enum): @@ -88,6 +84,15 @@ def _store_traceback(invocation_info: util.UniqueInvocationInfo): invocation_info.status_message = " ".join(invocation_info.status_message.split()) +def set_status_on_exception(build_state: build.State, stats: fs.Stats): + # We get `state` when the build tool succeeds, so we can use that to identify + # whether the exception was thrown during build or benchmark + if not build_state: + stats.save_model_eval_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.FAILED) + else: + stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.FAILED) + + def explore_invocation( model_inputs: dict, model_info: util.ModelInfo, @@ -138,21 +143,28 @@ def explore_invocation( inputs[all_args[i]] = args[i] invocation_info.inputs = inputs + # Create a build directory in the cache + fs.make_build_dir(tracer_args.cache_dir, build_name) + # If the user has not provided a specific runtime, select the runtime # based on the device provided. - if tracer_args.runtime is None: - selected_runtime = DEVICE_RUNTIME_MAP[tracer_args.device][DEFAULT_RUNTIME] - else: - selected_runtime = tracer_args.runtime + ( + selected_runtime, + runtime_info, + sequence_selected, + ) = plugins.select_runtime_and_sequence( + tracer_args.device, + tracer_args.runtime, + tracer_args.sequence, + ) - runtime_info = SUPPORTED_RUNTIMES[selected_runtime] if "status_stats" in runtime_info.keys(): invocation_info.stats_keys = runtime_info["status_stats"] else: invocation_info.stats_keys = [] # Create an ID for the build stats by combining the device and runtime. - # We don't need more info in the evaluation_id because changes to benchmark_model() + # We don't need more info in the evaluation_id because changes to build_model() # arguments (e.g., sequence) will trigger a rebuild, which is intended to replace the # build stats so long as the device and runtime have not changed. evaluation_id = f"{tracer_args.device}_{selected_runtime}" @@ -184,6 +196,13 @@ def explore_invocation( if fs.Keys.TASK in tracer_args.labels: stats.save_model_stat(fs.Keys.TASK, tracer_args.labels[fs.Keys.TASK][0]) + # Save the system information used for this evaluation + system_info = build.get_system_info() + stats.save_model_stat( + fs.Keys.SYSTEM_INFO, + system_info, + ) + # Save all of the lables in one place stats.save_model_stat(fs.Keys.LABELS, tracer_args.labels) @@ -219,72 +238,120 @@ def explore_invocation( tracer_args.iterations, ) + if model_info.model_type == build.ModelType.PYTORCH_COMPILED: + invocation_info.status_message = ( + "Skipping model compiled using torch.compile(). " + "turnkey requires models to be in eager mode " + "(regardless of what runtime you have selected)." + ) + invocation_info.status_message_color = printing.Colors.WARNING + + return + + build_state = None perf = None try: - if model_info.model_type == build.ModelType.PYTORCH_COMPILED: - invocation_info.status_message = ( - "Skipping model compiled using torch.compile(). " - "turnkey requires models to be in eager mode " - "(regardless of what runtime you have selected)." - ) - invocation_info.status_message_color = printing.Colors.WARNING - else: - # Indicate that the benchmark is running. If the build fails for any reason, + # Run the build tool (if needed by the runtime) + if runtime_info["build_required"]: + # Indicate that the build is running. If the build fails for any reason, # we will try to catch the exception and note it in the stats. # If a concluded build still has a status of "running", this means # there was an uncaught exception. - stats.save_model_eval_stat( - fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.RUNNING - ) + stats.save_model_eval_stat(fs.Keys.BUILD_STATUS, fs.FunctionStatus.RUNNING) - perf = benchmark_model( - model_info.model, - inputs, + build_state = build_model( + model=model_info.model, + inputs=inputs, evaluation_id=evaluation_id, - device=tracer_args.device, - runtime=selected_runtime, build_name=build_name, - iterations=tracer_args.iterations, cache_dir=tracer_args.cache_dir, - build_only=Action.BENCHMARK not in tracer_args.actions, - lean_cache=tracer_args.lean_cache, - sequence=tracer_args.sequence, - onnx_opset=tracer_args.onnx_opset, rebuild=tracer_args.rebuild, - rt_args=tracer_args.rt_args, + sequence=sequence_selected, + onnx_opset=tracer_args.onnx_opset, + device=tracer_args.device, + ) + + stats.save_model_eval_stat( + fs.Keys.BUILD_STATUS, fs.FunctionStatus.SUCCESSFUL ) - if Action.BENCHMARK in tracer_args.actions: - invocation_info.status_message = "Model successfully benchmarked!" - invocation_info.performance = perf - invocation_info.status_message_color = printing.Colors.OKGREEN + + model_to_benchmark = build_state.results[0] + + # Analyze the onnx file (if any) and save statistics + util.analyze_onnx( + build_name=build_name, + cache_dir=tracer_args.cache_dir, + stats=stats, + ) + else: + model_to_benchmark = model_info.model + + # Run the benchmark tool (if requested by the user) + if Action.BENCHMARK in tracer_args.actions: + if tracer_args.rt_args is None: + rt_args_to_use = {} else: - invocation_info.status_message = "Model successfully built!" - invocation_info.status_message_color = printing.Colors.OKGREEN + rt_args_to_use = tracer_args.rt_args + + stats.save_model_eval_stat( + fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.RUNNING + ) + + model_handle = runtime_info["RuntimeClass"]( + cache_dir=tracer_args.cache_dir, + build_name=build_name, + stats=stats, + iterations=tracer_args.iterations, + model=model_to_benchmark, + inputs=inputs, + device_type=tracer_args.device, + runtime=selected_runtime, + **rt_args_to_use, + ) + perf = model_handle.benchmark() + + for key, value in vars(perf).items(): + stats.save_model_eval_stat( + key=key, + value=value, + ) + + stats.save_model_eval_stat( + fs.Keys.BENCHMARK_STATUS, fs.FunctionStatus.SUCCESSFUL + ) + + invocation_info.status_message = "Model successfully benchmarked!" + invocation_info.performance = perf + invocation_info.status_message_color = printing.Colors.OKGREEN + else: + invocation_info.status_message = "Model successfully built!" + invocation_info.status_message_color = printing.Colors.OKGREEN except exp.StageError as e: invocation_info.status_message = f"Build Error: {e}" invocation_info.status_message_color = printing.Colors.WARNING - stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED) + set_status_on_exception(build_state, stats) _store_traceback(invocation_info) except exp.SkipBuild: # SkipBuild is an exception that the build_model() API will raise # when it is skipping a previously-failed build when rebuild=never is set + + # NOTE: skipping a build should never update build or benchmark status + invocation_info.status_message = ( "Build intentionally skipped because rebuild=never" ) invocation_info.status_message_color = printing.Colors.WARNING - stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.KILLED) - except exp.ArgError as e: - # ArgError indicates that some argument to benchmark_model() was + # ArgError indicates that some argument to build_model() or BaseRT was # illegal. In that case we want to halt execution so that users can # fix their arguments. - stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED) + set_status_on_exception(build_state, stats) raise e @@ -292,7 +359,7 @@ def explore_invocation( invocation_info.status_message = f"Error: {e}." invocation_info.status_message_color = printing.Colors.WARNING - stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED) + set_status_on_exception(build_state, stats) _store_traceback(invocation_info) @@ -302,66 +369,19 @@ def explore_invocation( invocation_info.status_message = f"Unknown turnkey error: {e}" invocation_info.status_message_color = printing.Colors.WARNING - stats.save_model_eval_stat(fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.FAILED) + set_status_on_exception(build_state, stats) _store_traceback(invocation_info) - else: - # If there was no exception then we consider the build to be a success - stats.save_model_eval_stat( - fs.Keys.BENCHMARK_STATUS, fs.BenchmarkStatus.SUCCESSFUL - ) finally: # Ensure that stdout/stderr is not being forwarded before updating status util.stop_logger_forward() - system_info = build.get_system_info() - stats.save_model_stat( - fs.Keys.SYSTEM_INFO, - system_info, - ) - if model_info.model_type != build.ModelType.PYTORCH_COMPILED: - # We have this if-block because torch-compiled model instances - # are not legal input to this function. So when we encounter one, - # we want to exit the function as quickly as possible, without - # doing any of the logic that follows this comment. - - # ONNX stats that we want to save into the build's turnkey_stats.yaml file - # so that they can be easily accessed by the report command later - if fs.Keys.ONNX_FILE in stats.evaluation_stats.keys(): - # Just in case the ONNX file was generated on a different machine: - # strip the state's cache dir, then prepend the current cache dir - final_onnx_file = fs.rebase_cache_dir( - stats.evaluation_stats[fs.Keys.ONNX_FILE], - build_name, - tracer_args.cache_dir, - ) - - onnx_ops_counter = util.get_onnx_ops_list(final_onnx_file) - onnx_model_info = util.populate_onnx_model_info(final_onnx_file) - onnx_input_dimensions = util.onnx_input_dimensions(final_onnx_file) - - stats.save_model_stat( - fs.Keys.ONNX_OPS_COUNTER, - onnx_ops_counter, - ) - stats.save_model_stat( - fs.Keys.ONNX_MODEL_INFO, - onnx_model_info, - ) - stats.save_model_stat( - fs.Keys.ONNX_INPUT_DIMENSIONS, - onnx_input_dimensions, - ) - - if perf: - for key, value in vars(perf).items(): - stats.save_model_eval_stat( - key=key, - value=value, - ) + status.update(tracer_args.models_found, build_name, tracer_args.cache_dir) - status.update(tracer_args.models_found, build_name, tracer_args.cache_dir) + if tracer_args.lean_cache: + printing.log_info("Removing build artifacts...") + fs.clean_output_dir(tracer_args.cache_dir, build_name) def get_model_hash( @@ -610,7 +630,7 @@ def forward_spy(*args, **kwargs): ) invocation_info.executed = invocation_info.executed + 1 - # Call benchmark_model() if this is the first time the model is being executed + # Call explore_invocation() if this is the first time the model is being executed # and this model has been selected by the user if ( invocation_info.executed == 1 @@ -623,7 +643,7 @@ def forward_spy(*args, **kwargs): invocation_info=invocation_info, tracer_args=tracer_args, ) - # Ensure that benchmark_model() doesn't interfere with our execution count + # Ensure that explore_invocation() doesn't interfere with our execution count model_info.executed = 1 build_name = fs.get_build_name( @@ -795,7 +815,7 @@ def evaluate_script(tracer_args: TracerArgs) -> Dict[str, util.ModelInfo]: "torch.jit.script(", "torch.jit.script() is not supported by turnkey CLI and benchmark_files() API, " "however torch.jit.script() is being called in your script." - "You can try passing your model instance into the benchmark_model() API instead. ", + "You can try passing your model instance into the build_model() API instead. ", ) ] ): diff --git a/src/turnkeyml/analyze/util.py b/src/turnkeyml/analyze/util.py index 27e96594..305af95c 100644 --- a/src/turnkeyml/analyze/util.py +++ b/src/turnkeyml/analyze/util.py @@ -8,7 +8,7 @@ from turnkeyml.common import printing import turnkeyml.common.build as build from turnkeyml.common.performance import MeasuredPerformance -from turnkeyml.common.filesystem import Stats +import turnkeyml.common.filesystem as fs class AnalysisException(Exception): @@ -37,7 +37,7 @@ class UniqueInvocationInfo: status_message_color: printing.Colors = printing.Colors.ENDC traceback_message_color: printing.Colors = printing.Colors.FAIL stats_keys: Optional[List[str]] = None - stats: Stats = None + stats: fs.Stats = None @dataclass @@ -162,3 +162,33 @@ def stop_logger_forward() -> None: sys.stdout = sys.stdout.terminal if hasattr(sys.stderr, "terminal_err"): sys.stderr = sys.stderr.terminal_err + + +def analyze_onnx(build_name: str, cache_dir: str, stats: fs.Stats): + # ONNX stats that we want to save into the build's turnkey_stats.yaml file + # so that they can be easily accessed by the report command later + if fs.Keys.ONNX_FILE in stats.evaluation_stats.keys(): + # Just in case the ONNX file was generated on a different machine: + # strip the state's cache dir, then prepend the current cache dir + final_onnx_file = fs.rebase_cache_dir( + stats.evaluation_stats[fs.Keys.ONNX_FILE], + build_name, + cache_dir, + ) + + onnx_ops_counter = get_onnx_ops_list(final_onnx_file) + onnx_model_info = populate_onnx_model_info(final_onnx_file) + input_dimensions = onnx_input_dimensions(final_onnx_file) + + stats.save_model_stat( + fs.Keys.ONNX_OPS_COUNTER, + onnx_ops_counter, + ) + stats.save_model_stat( + fs.Keys.ONNX_MODEL_INFO, + onnx_model_info, + ) + stats.save_model_stat( + fs.Keys.ONNX_INPUT_DIMENSIONS, + input_dimensions, + ) diff --git a/src/turnkeyml/build/ignition.py b/src/turnkeyml/build/ignition.py index 11184a12..86c51c94 100644 --- a/src/turnkeyml/build/ignition.py +++ b/src/turnkeyml/build/ignition.py @@ -222,9 +222,16 @@ def _begin_fresh_build( # start with a fresh State. stats = filesystem.Stats(state_args["cache_dir"], state_args["config"].build_name) + build_dir = build.output_dir( + state_args["cache_dir"], state_args["config"].build_name + ) + filesystem.rmdir( - build.output_dir(state_args["cache_dir"], state_args["config"].build_name), - exclude=stats.file, + build_dir, + excludes=[ + stats.file, + os.path.join(build_dir, filesystem.BUILD_MARKER), + ], ) state = state_type(**state_args) state.save() diff --git a/src/turnkeyml/build_api.py b/src/turnkeyml/build_api.py index 349044e9..b9a6aed6 100644 --- a/src/turnkeyml/build_api.py +++ b/src/turnkeyml/build_api.py @@ -118,19 +118,8 @@ def build_model( sequence_locked.show_monitor(config, state.monitor) state = sequence_locked.launch(state) - if state.build_status == build.Status.SUCCESSFUL_BUILD: - printing.log_success( - f"\n Saved to **{build.output_dir(state.cache_dir, config.build_name)}**" - ) - - return state + printing.log_success( + f"\n Saved to **{build.output_dir(state.cache_dir, config.build_name)}**" + ) - else: - printing.log_success( - f"Build Sequence {sequence_locked.unique_name} completed successfully" - ) - msg = """ - build_model() only returns a Model instance if the Sequence includes a Stage - that sets state.build_status=turnkey.common.build.Status.SUCCESSFUL_BUILD. - """ - printing.log_warning(msg) + return state diff --git a/src/turnkeyml/cli/report.py b/src/turnkeyml/cli/report.py index ec49a9fc..2c2f4281 100644 --- a/src/turnkeyml/cli/report.py +++ b/src/turnkeyml/cli/report.py @@ -84,14 +84,14 @@ def summary_spreadsheets(args) -> None: for subkey, subvalue in value.items(): evaluation_stats[subkey] = subvalue - # If a build is still marked as "running" at reporting time, it + # If a build or benchmark is still marked as "running" at + # reporting time, it # must have been killed by a time out, out-of-memory (OOM), or some # other uncaught exception if ( - key == fs.Keys.BENCHMARK_STATUS - and value == fs.BenchmarkStatus.RUNNING - ): - value = fs.BenchmarkStatus.KILLED + key == fs.Keys.BUILD_STATUS or fs.Keys.BENCHMARK_STATUS + ) and value == fs.FunctionStatus.RUNNING: + value = fs.FunctionStatus.KILLED evaluation_stats[key] = value diff --git a/src/turnkeyml/common/filesystem.py b/src/turnkeyml/common/filesystem.py index 961cb6bd..d2190336 100644 --- a/src/turnkeyml/common/filesystem.py +++ b/src/turnkeyml/common/filesystem.py @@ -39,22 +39,29 @@ MODELS_DIR = importlib.util.find_spec("turnkeyml_models").submodule_search_locations[0] -def rmdir(folder, exclude: Optional[str] = None): +def rmdir(folder, excludes: Optional[List[str]] = None): """ Remove the contents of a directory from the filesystem. - If `exclude=`, the directory itself and the file named + If `` is in `excludes`, the directory itself and the file named are kept. Otherwise, the entire directory is removed. """ + + # Use an empty list by default + if excludes: + excludes_to_use = excludes + else: + excludes_to_use = [] + if os.path.isdir(folder): for filename in os.listdir(folder): file_path = os.path.join(folder, filename) - if file_path != exclude: + if file_path not in excludes_to_use: if os.path.isfile(file_path) or os.path.islink(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) - if exclude is None: + if excludes is None: shutil.rmtree(folder) return True @@ -347,11 +354,13 @@ class Keys: SYSTEM_INFO = "system_info" # Path to the built-in model script used as input MODEL_SCRIPT = "builtin_model_script" - # Indicates a benchmark's status: running, successful, failed, or killed + # Indicates status of the most recent build tool run: FunctionStatus + BUILD_STATUS = "build_status" + # Indicates status of the most recent benchmark tool run: FunctionStatus BENCHMARK_STATUS = "benchmark_status" -class BenchmarkStatus: +class FunctionStatus: RUNNING = "running" SUCCESSFUL = "successful" FAILED = "failed" diff --git a/src/turnkeyml/model_api.py b/src/turnkeyml/model_api.py deleted file mode 100644 index d7cb155e..00000000 --- a/src/turnkeyml/model_api.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Dict, Optional, Union, List -from turnkeyml.build_api import build_model -from turnkeyml.build.stage import Sequence -import turnkeyml.common.printing as printing -import turnkeyml.common.filesystem as filesystem -from turnkeyml.common.performance import MeasuredPerformance -from turnkeyml.run.devices import ( - SUPPORTED_DEVICES, - SUPPORTED_RUNTIMES, - DEVICE_RUNTIME_MAP, - apply_default_runtime, -) -import turnkeyml.build.sequences as sequences -import turnkeyml.common.exceptions as exp - -TURNKEY_DEFAULT_REBUILD_POLICY = "if_needed" - - -def benchmark_model( - model: Any, - inputs: Dict[str, Any], - build_name: str, - iterations: int = 100, - evaluation_id: str = "build", - cache_dir: str = filesystem.DEFAULT_CACHE_DIR, - device: str = "x86", - runtime: Optional[str] = None, - build_only: bool = False, - lean_cache: bool = False, - rebuild: str = TURNKEY_DEFAULT_REBUILD_POLICY, - onnx_opset: Optional[int] = None, - sequence: Sequence = None, - rt_args: Optional[Dict[str, Union[str, List[str]]]] = None, -) -> MeasuredPerformance: - """ - Benchmark a model against some inputs on target hardware - """ - - selected_runtime = apply_default_runtime(device, runtime) - - # Build and benchmark the model - try: - # Validate device and runtime selections - if device not in SUPPORTED_DEVICES: - raise exp.ArgError( - f"Device argument '{device}' is not one of the available " - f"supported devices {SUPPORTED_DEVICES}\n" - f"You may need to check the spelling of '{device}', install a " - "plugin, or update the turnkeyml package." - ) - else: - if selected_runtime not in DEVICE_RUNTIME_MAP[device]: - raise exp.ArgError( - f"Runtime argument '{selected_runtime}' is not one of the available " - f"runtimes supported for device '{device}': {DEVICE_RUNTIME_MAP[device]}\n" - f"You may need to check the spelling of '{selected_runtime}', install a " - "plugin, or update the turnkeyml package." - ) - - # Get the plugin module for the selected runtime - runtime_info = SUPPORTED_RUNTIMES[selected_runtime] - - # Perform a build, if necessary - if runtime_info["build_required"]: - # Get the build sequence that will be used for the model - if sequence is None: - # Automatically choose a Sequence based on what the runtime expects - sequence_selected = runtime_info["default_sequence"] - else: - # User-specified Sequence - if isinstance(sequence, str): - # Sequence is defined by a plugin - if sequence in sequences.SUPPORTED_SEQUENCES.keys(): - sequence_selected = sequences.SUPPORTED_SEQUENCES[sequence] - else: - raise ValueError( - f"Sequence argument {sequence} is not one of the " - "available sequences installed: " - f"{sequences.SUPPORTED_SEQUENCES.keys()} \n" - f"You may need to check the spelling of `{sequence}`, " - "install a plugin, or update the turnkeyml package." - ) - - elif isinstance(sequence, Sequence): - # Sequence is a user-defined instance of Sequence - sequence_selected = sequence - - build_model( - model=model, - inputs=inputs, - evaluation_id=evaluation_id, - build_name=build_name, - cache_dir=cache_dir, - rebuild=rebuild, - sequence=sequence_selected, - onnx_opset=onnx_opset, - device=device, - ) - - # Perform benchmarking, if requested - if not build_only: - if rt_args is None: - rt_args_to_use = {} - else: - rt_args_to_use = rt_args - - printing.log_info(f"Benchmarking on {device}...") - stats = filesystem.Stats(cache_dir, build_name, evaluation_id) - model_handle = runtime_info["RuntimeClass"]( - cache_dir=cache_dir, - build_name=build_name, - stats=stats, - iterations=iterations, - model=model, - inputs=inputs, - device_type=device, - runtime=selected_runtime, - **rt_args_to_use, - ) - perf = model_handle.benchmark() - - finally: - # Make sure the build and cache dirs exist and have the proper marker files - # NOTE: We would do this at the top of the file, however - # there are conditions where build_model() will wipe the build dir, - # which would eliminate our marker file - filesystem.make_build_dir(cache_dir, build_name) - - # Clean cache if needed - if lean_cache: - printing.log_info("Removing build artifacts...") - filesystem.clean_output_dir(cache_dir, build_name) - - if not build_only: - perf.print() - return perf - else: - return None diff --git a/src/turnkeyml/run/devices.py b/src/turnkeyml/run/devices.py index a5411f37..c48d81b5 100644 --- a/src/turnkeyml/run/devices.py +++ b/src/turnkeyml/run/devices.py @@ -1,9 +1,12 @@ from typing import Optional -from typing import List, Dict +from typing import List, Dict, Tuple import turnkeyml.run.onnxrt as onnxrt import turnkeyml.run.tensorrt as tensorrt import turnkeyml.run.torchrt as torchrt import turnkeyml.common.plugins as plugins +from turnkeyml.build.stage import Sequence +import turnkeyml.build.sequences as sequences +import turnkeyml.common.exceptions as exp def supported_devices_list(data: Dict, parent_key: str = "") -> List: @@ -72,3 +75,63 @@ def apply_default_runtime(device: str, runtime: Optional[str] = None): return DEVICE_RUNTIME_MAP[device][DEFAULT_RUNTIME] else: return runtime + + +def _check_suggestion(value: str): + return ( + f"You may need to check the spelling of '{value}', install a " + "plugin, or update the turnkeyml package." + ) + + +def select_runtime_and_sequence( + device: str, runtime: Optional[str], sequence: Optional[Sequence] +) -> Tuple[str, str, Sequence]: + selected_runtime = apply_default_runtime(device, runtime) + + # Validate device and runtime selections + if device not in SUPPORTED_DEVICES: + raise exp.ArgError( + f"Device argument '{device}' is not one of the available " + f"supported devices {SUPPORTED_DEVICES}\n" + f"{_check_suggestion(device)}" + ) + if selected_runtime not in DEVICE_RUNTIME_MAP[device]: + raise exp.ArgError( + f"Runtime argument '{selected_runtime}' is not one of the available " + f"runtimes supported for device '{device}': {DEVICE_RUNTIME_MAP[device]}\n" + f"{_check_suggestion(selected_runtime)}" + ) + + # Get the plugin module for the selected runtime + runtime_info = SUPPORTED_RUNTIMES[selected_runtime] + + # Perform a build, if necessary + if runtime_info["build_required"]: + # Get the build sequence that will be used for the model + if sequence is None: + # Automatically choose a Sequence based on what the runtime expects + sequence_selected = runtime_info["default_sequence"] + else: + # User-specified Sequence + if isinstance(sequence, str): + # Sequence is defined by a plugin + if sequence in sequences.SUPPORTED_SEQUENCES.keys(): + sequence_selected = sequences.SUPPORTED_SEQUENCES[sequence] + else: + raise ValueError( + f"Sequence argument {sequence} is not one of the " + "available sequences installed: " + f"{sequences.SUPPORTED_SEQUENCES.keys()} \n" + f"{_check_suggestion(sequence)}" + ) + + elif isinstance(sequence, Sequence): + # Sequence is a user-defined instance of Sequence + sequence_selected = sequence + + else: + # Sequence is only needed for builds + sequence_selected = None + + return selected_runtime, runtime_info, sequence_selected diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py index 493f7415..6a9beea8 100644 --- a/src/turnkeyml/version.py +++ b/src/turnkeyml/version.py @@ -1 +1 @@ -__version__ = "0.3.0" +__version__ = "0.4.0" diff --git a/test/cli.py b/test/cli.py index e555bc7f..d5731b23 100644 --- a/test/cli.py +++ b/test/cli.py @@ -389,7 +389,7 @@ def test_005_cli_list(self): for test_script in common.test_scripts_dot_py.keys(): script_name = common.strip_dot_py(test_script) - assert script_name in f.getvalue() + assert script_name in f.getvalue(), f"{script_name} {f.getvalue()}" def test_006_cli_delete(self): # NOTE: this is not a unit test, it relies on other command @@ -976,6 +976,7 @@ def test_028_cli_timeout(self): "--process-isolation", "--timeout", "10", + "--build-only", ] with patch.object(sys, "argv", flatten(testargs)): turnkeycli() @@ -1000,8 +1001,8 @@ def test_028_cli_timeout(self): try: timeout_summary = summary[0] - assert timeout_summary["benchmark_status"] == "killed", timeout_summary[ - "benchmark_status" + assert timeout_summary["build_status"] == "killed", timeout_summary[ + "build_status" ] except IndexError: # Edge case where the CSV is empty because the build timed out before diff --git a/test/model_api.py b/test/model_api.py deleted file mode 100644 index 6822da93..00000000 --- a/test/model_api.py +++ /dev/null @@ -1,171 +0,0 @@ -import os -import unittest -import torch -import shutil -import onnx -import platform -import turnkeyml.build.stage as stage -import turnkeyml.common.filesystem as filesystem -import turnkeyml.build.export as export -import turnkeyml.common.build as build -from turnkeyml import benchmark_model -from helpers import common - - -class SmallPytorchModel(torch.nn.Module): - def __init__(self): - super(SmallPytorchModel, self).__init__() - self.fc = torch.nn.Linear(10, 5) - - def forward(self, x): - output = self.fc(x) - return output - - -class AnotherSimplePytorchModel(torch.nn.Module): - def __init__(self): - super(AnotherSimplePytorchModel, self).__init__() - self.relu = torch.nn.ReLU() - - def forward(self, x): - output = self.relu(x) - return output - - -# Define pytorch model and inputs -pytorch_model = SmallPytorchModel() -tiny_pytorch_model = AnotherSimplePytorchModel() -inputs = {"x": torch.rand(10)} -inputs_2 = {"x": torch.rand(5)} -input_tensor = torch.rand(10) - -# Create a test directory -cache_dir, _ = common.create_test_dir("cli") - - -def get_build_state(cache_dir, build_name): - return build.load_state(cache_dir=cache_dir, build_name=build_name) - - -class Testing(unittest.TestCase): - def setUp(self) -> None: - filesystem.rmdir(cache_dir) - return super().setUp() - - def test_001_build_pytorch_model(self): - build_name = "build_pytorch_model" - benchmark_model( - pytorch_model, - inputs, - build_name=build_name, - rebuild="always", - build_only=True, - cache_dir=cache_dir, - runtime="ort", - ) - state = get_build_state(cache_dir, build_name) - assert state.build_status == build.Status.SUCCESSFUL_BUILD - - def test_002_custom_stage(self): - build_name = "custom_stage" - - class MyCustomStage(stage.Stage): - def __init__(self, funny_saying): - super().__init__( - unique_name="funny_stage", - monitor_message="Funny Stage", - ) - - self.funny_saying = funny_saying - - def fire(self, state): - print(f"funny message: {self.funny_saying}") - state.build_status = build.Status.SUCCESSFUL_BUILD - return state - - my_custom_stage = MyCustomStage( - funny_saying="Is a fail whale a fail at all if it makes you smile?" - ) - my_sequence = stage.Sequence( - unique_name="my_sequence", - monitor_message="Running My Sequence", - stages=[ - export.ExportPytorchModel(), - export.OptimizeOnnxModel(), - my_custom_stage, - ], - ) - - benchmark_model( - pytorch_model, - inputs, - build_name=build_name, - rebuild="always", - sequence=my_sequence, - build_only=True, - cache_dir=cache_dir, - runtime="ort", - ) - - state = get_build_state(cache_dir, build_name) - return state.build_status == build.Status.SUCCESSFUL_BUILD - - # TODO: Investigate why this test is only failing on Windows CI failing - @unittest.skipIf(platform.system() == "Windows", "Windows CI only failure") - def test_003_local_benchmark(self): - build_name = "local_benchmark" - perf = benchmark_model( - pytorch_model, - inputs, - device="x86", - build_name=build_name, - rebuild="always", - cache_dir=cache_dir, - lean_cache=True, - runtime="ort", - ) - state = get_build_state(cache_dir, build_name) - assert state.build_status == build.Status.SUCCESSFUL_BUILD - assert os.path.isfile( - os.path.join(cache_dir, build_name, "x86_benchmark/outputs.json") - ) - assert perf.mean_latency > 0 - assert perf.throughput > 0 - - # TODO: Investigate why this test is only failing on Windows CI failing - @unittest.skipIf(platform.system() == "Windows", "Windows CI only issue") - def test_004_onnx_opset(self): - """ - Make sure we can successfully benchmark a model with a user-defined ONNX opset - """ - - build_name = "onnx_opset" - - user_opset = 15 - assert user_opset != build.DEFAULT_ONNX_OPSET - - perf = benchmark_model( - pytorch_model, - inputs, - device="x86", - build_name=build_name, - rebuild="always", - cache_dir=cache_dir, - onnx_opset=user_opset, - runtime="ort", - ) - state = get_build_state(cache_dir, build_name) - assert state.build_status == build.Status.SUCCESSFUL_BUILD - assert os.path.isfile( - os.path.join(cache_dir, build_name, "x86_benchmark/outputs.json") - ) - assert perf.mean_latency > 0 - assert perf.throughput > 0 - - onnx_model = onnx.load(state.results[0]) - model_opset = getattr(onnx_model.opset_import[0], "version", None) - assert user_opset == model_opset - - -if __name__ == "__main__": - unittest.main()