From 97435b70bd3cdb2b444f53d79218c55d57b326e6 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 May 2020 20:22:00 -0400 Subject: [PATCH] Cache test_utils.format_shape_and_dtype_string. A significant fraction of time when collecting test cases is spent building shape and dtype strings (which are usually similar and usually thrown away.) --- jax/test_util.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/jax/test_util.py b/jax/test_util.py index b11e8730853a..ff9f82d79edf 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -425,17 +425,23 @@ def dtype_str(dtype): def format_shape_dtype_string(shape, dtype): + if isinstance(shape, onp.ndarray): + return f'{dtype_str(dtype)}[{shape}]' + elif isinstance(shape, list): + shape = tuple(shape) + return _format_shape_dtype_string(shape, dtype) + +@functools.lru_cache(maxsize=64) +def _format_shape_dtype_string(shape, dtype): if shape is NUMPY_SCALAR_SHAPE: return dtype_str(dtype) elif shape is PYTHON_SCALAR_SHAPE: return 'py' + dtype_str(dtype) - elif type(shape) in (list, tuple): + elif type(shape) is tuple: shapestr = ','.join(str(dim) for dim in shape) return '{}[{}]'.format(dtype_str(dtype), shapestr) elif type(shape) is int: return '{}[{},]'.format(dtype_str(dtype), shape) - elif isinstance(shape, onp.ndarray): - return '{}[{}]'.format(dtype_str(dtype), shape) else: raise TypeError(type(shape))