Skip to content

Commit

Permalink
Optimize LazyValues and SparseValues with Caching Mechanism
Browse files Browse the repository at this point in the history
Signed-off-by: Phoenix <861062923@qq.com>
  • Loading branch information
Phoenix8215 committed Sep 21, 2024
1 parent c56f4c7 commit ff143f1
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions tools/onnx-graphsurgeon/onnx_graphsurgeon/ir/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,18 @@ def __init__(self, tensor):
self.shape = get_onnx_tensor_shape(self.tensor)
self.dtype = get_onnx_tensor_dtype(self.tensor)
self.nbytes = misc.volume(self.shape) * get_itemsize(self.dtype)
self._cached_values = None # Initialize the cache

def load(self):
"""
Load a numpy array from the underlying tensor values.
Load a numpy array from the underlying tensor values, using cache.
Returns:
np.array: A numpy array containing the values of the tensor.
"""
if self._cached_values is not None:
return self._cached_values # Return cached data if available

import onnx
import onnx.numpy_helper
from onnx_graphsurgeon.importers.onnx_importer import (
Expand All @@ -254,7 +258,8 @@ def load(self):
f"If this is not what you intended, please avoid accessing the values of this constant tensor."
)

return np.array(onnx.numpy_helper.to_array(self.tensor))
self._cached_values = np.array(onnx.numpy_helper.to_array(self.tensor))
return self._cached_values

def __str__(self):
return "LazyValues (shape={:}, dtype={:})".format(self.shape, self.dtype)
Expand All @@ -268,13 +273,20 @@ class SparseValues(LazyValues):
A special object that represents constant tensor values that is sparse
"""

def __init__(self, tensor):
super().__init__(tensor)
self._cached_values = None # Initialize the cache

def load(self):
"""
Load a numpy array from the sparse structure.
Load a numpy array from the sparse structure, using cache.
Returns:
np.array: A numpy array containing the values of the tensor.
"""
if self._cached_values is not None:
return self._cached_values # Return cached data if available

import onnx
import onnx.numpy_helper
from onnx_graphsurgeon.importers.onnx_importer import (
Expand Down Expand Up @@ -316,7 +328,8 @@ def load(self):
f"Unsupported index data dims {self.tensor.indices.dims} in {self.tensor.values.name}"
)

return values
self._cached_values = values
return self._cached_values

def __str__(self):
return "SparseValues (shape={:}, dtype={:})".format(self.shape, self.dtype)
Expand Down

0 comments on commit ff143f1

Please sign in to comment.