diff --git a/examples/rcn.py b/examples/rcn.py index 7768f3b1..3a37bbb7 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -39,7 +39,7 @@ # Use train_size = 100 if you have a gpu with atleast 8Gbs memory. # Recommend that jax is installed with cuda enabled for this option. -train_size = 100 +train_size = 20 test_size = 20 @@ -98,7 +98,9 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5): # # 2. Load the model # %% -data = np.load("example_data/rcn_100.npz", allow_pickle=True, encoding="latin1") +data = np.load( + f"example_data/rcn_{train_size}.npz", allow_pickle=True, encoding="latin1" +) frcs, edges, suppression_masks, filters = ( data["frcs"][:train_size], data["edges"][:train_size],