From f4772152f0aaa2178504e58c780a94c02198afae Mon Sep 17 00:00:00 2001
From: Anselm Hahn
Date: Sat, 14 May 2022 08:31:49 +0200
Subject: [PATCH] Fixed: #1722 Run out of memory

Update utils.py and test by using `break` instead of `raise`
---
 autokeras/utils/utils.py      |  9 +++++++--
 autokeras/utils/utils_test.py | 21 +++++++++++----------
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/autokeras/utils/utils.py b/autokeras/utils/utils.py
index 9552b9e54..58b15f130 100644
--- a/autokeras/utils/utils.py
+++ b/autokeras/utils/utils.py
@@ -94,15 +94,20 @@ def fit_with_adaptive_batch_size(model, batch_size, **fit_kwargs):
 def run_with_adaptive_batch_size(batch_size, func, **fit_kwargs):
     x = fit_kwargs.pop("x")
     validation_data = None
+    history = None
     if "validation_data" in fit_kwargs:
         validation_data = fit_kwargs.pop("validation_data")
     while batch_size > 0:
         try:
             history = func(x=x, validation_data=validation_data, **fit_kwargs)
             break
-        except tf.errors.ResourceExhaustedError as e:
+        except tf.errors.ResourceExhaustedError:
             if batch_size == 1:
-                raise e
+                print(
+                    "Not enough memory, reduced batch size is already set to 1. "
+                    "Current model will be skipped."
+                )
+                break
             batch_size //= 2
             print(
                 "Not enough memory, reduce batch size to {batch_size}.".format(
diff --git a/autokeras/utils/utils_test.py b/autokeras/utils/utils_test.py
index 0b3acfcc4..12234beb2 100644
--- a/autokeras/utils/utils_test.py
+++ b/autokeras/utils/utils_test.py
@@ -53,19 +53,20 @@ def test_check_kt_version_error():
     )
 
 
-def test_run_with_adaptive_batch_size_raise_error():
+def test_run_with_adaptive_batch_size_raise_error(capfd):
     def func(**kwargs):
         raise tf.errors.ResourceExhaustedError(0, "", None)
 
-    with pytest.raises(tf.errors.ResourceExhaustedError):
-        utils.run_with_adaptive_batch_size(
-            batch_size=64,
-            func=func,
-            x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
-            validation_data=tf.data.Dataset.from_tensor_slices(
-                np.random.rand(100, 1)
-            ).batch(64),
-        )
+    utils.run_with_adaptive_batch_size(
+        batch_size=64,
+        func=func,
+        x=tf.data.Dataset.from_tensor_slices(np.random.rand(100, 1)).batch(64),
+        validation_data=tf.data.Dataset.from_tensor_slices(
+            np.random.rand(100, 1)
+        ).batch(64),
+    )
+    std, _ = capfd.readouterr()
+    assert "Not enough memory" in std
 
 
 def test_get_hyperparameter_with_none_return_hp():