From fda1bd9a24693d61ca4bd00092a6c4c9ebc4ecab Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Fri, 23 Feb 2024 20:31:51 +0000 Subject: [PATCH] external data fix --- src/sparsezoo/utils/onnx/external_data.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/sparsezoo/utils/onnx/external_data.py b/src/sparsezoo/utils/onnx/external_data.py index 636de3a9b..aca41dd47 100644 --- a/src/sparsezoo/utils/onnx/external_data.py +++ b/src/sparsezoo/utils/onnx/external_data.py @@ -76,6 +76,7 @@ def save_onnx( model_path: str, max_external_file_size: int = 16e9, external_data_file: Optional[str] = None, + do_split_external_data: bool = True, ) -> bool: """ Save model to the given path. @@ -95,6 +96,8 @@ def save_onnx( specified in the variable EXTERNAL_ONNX_DATA_NAME :param max_external_file_size: The maximum file size in bytes of a single split external data out file. Defaults to 16000000000 (16e9 = 16GB) + :param do_split_external_data: True to split external data file into chunks of max + size max_external_file_size, false otherwise :return True if the model was saved with external data, False otherwise. """ if external_data_file is not None: @@ -112,7 +115,8 @@ def save_onnx( all_tensors_to_one_file=True, location=external_data_file, ) - split_external_data(model_path, max_file_size=max_external_file_size) + if do_split_external_data: + split_external_data(model_path, max_file_size=max_external_file_size) return True if model.ByteSize() > DUMP_EXTERNAL_DATA_THRESHOLD: @@ -132,7 +136,8 @@ def save_onnx( all_tensors_to_one_file=True, location=external_data_file, ) - split_external_data(model_path, max_file_size=max_external_file_size) + if do_split_external_data: + split_external_data(model_path, max_file_size=max_external_file_size) return True onnx.save(model, model_path) @@ -247,6 +252,9 @@ def split_external_data( f"{external_data_file_path} not found. {model_path} must have external " "data written to a single file in the same directory" ) + if os.path.getsize(external_data_file_path) <= max_file_size: + # return immediately if file is small enough to not split + return # UPDATE: external data info of graph tensors so they point to the new split out # files with updated offsets @@ -300,14 +308,6 @@ def split_external_data( # WRITE - ONNX model with updated tensor external data info onnx.save(model, model_path) - # RENAME - if as a result of splitting we end up with a single file, rename it to - # the original external data file name - if current_external_data_file_number == 1: - os.rename( - os.path.join(base_dir, updated_file_name), - os.path.join(base_dir, external_data_file), - ) - def _write_external_data_file_from_base_bytes( new_file_name, original_byte_ranges, original_file_bytes_reader