diff --git a/src/turnkeyml/build/export.py b/src/turnkeyml/build/export.py
index 945af010..092695cf 100644
--- a/src/turnkeyml/build/export.py
+++ b/src/turnkeyml/build/export.py
@@ -280,18 +280,24 @@ def fire(self, state: build.State):
         default_warnings = warnings.showwarning
         warnings.showwarning = _warn_to_stdout
 
-        stats = fs.Stats(
-            state.cache_dir, state.config.build_name, state.evaluation_id
-        )
+        stats = fs.Stats(state.cache_dir, state.config.build_name, state.evaluation_id)
 
         # Verify if the exported model matches the input torch model
         try:
+            # Tolerance levels for the torch export are recommended by PyTorch here:
+            # https://pytorch.org/docs/stable/testing.html#module-torch.testing
+            fp32_tolerance = torch.onnx.verification.VerificationOptions(
+                rtol=1.3e-6, atol=1e-5
+            )
+
             # The `torch.onnx.verification.find_mismatch()` takes input arguments to the
             # model as `input_args (Tuple[Any, ...])`
             export_verification = torch.onnx.verification.find_mismatch(
                 state.model,
                 tuple(state.inputs.values()),
-                opset_version=state.config.onnx_opset)
+                opset_version=state.config.onnx_opset,
+                options=fp32_tolerance,
+            )
 
             # `export_verification.has_mismatch()` returns True if a mismatch is found and
             # False otherwise. If no mismatch is found, `is_export_valid` is set to "Valid",
@@ -311,8 +317,8 @@ def fire(self, state: build.State):
             is_export_valid = "unverified"
 
         stats.save_model_eval_stat(
-            fs.Keys.TORCH_ONNX_EXPORT_VALIDITY, 
-            is_export_valid, 
+            fs.Keys.TORCH_ONNX_EXPORT_VALIDITY,
+            is_export_valid,
        )
 
         # Export the model to ONNX
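
For context, below is a minimal standalone sketch of the verification flow this patch wires in: `torch.onnx.verification.find_mismatch()` with the fp32 tolerances from the PyTorch testing docs. The toy model, input, and opset number are illustrative placeholders (not from the patch), and onnxruntime must be installed because find_mismatch() runs the exported graph through ONNX Runtime by default.

# Illustrative sketch only; TinyModel, its input, and opset_version=17 are
# made-up stand-ins for state.model, state.inputs, and state.config.onnx_opset.
import torch
import torch.onnx.verification as verification


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.linear(x))


model = TinyModel().eval()
inputs = {"x": torch.randn(1, 4)}

# Same fp32 tolerances the patch uses, per
# https://pytorch.org/docs/stable/testing.html#module-torch.testing
fp32_tolerance = verification.VerificationOptions(rtol=1.3e-6, atol=1e-5)

# find_mismatch() exports the model to ONNX and bisects the graph to locate
# any operator whose ONNX Runtime output diverges from the torch output
# beyond the given tolerances. It returns a GraphInfo object.
graph_info = verification.find_mismatch(
    model,
    tuple(inputs.values()),
    opset_version=17,
    options=fp32_tolerance,
)

# has_mismatch() is True if any node diverged; the patch maps this result to
# the TORCH_ONNX_EXPORT_VALIDITY stat it records.
print("export matches torch model:", not graph_info.has_mismatch())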