diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9e3934ec..e6c23138 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

+- Add implicit MAML Omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
- Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104).
- Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93).
diff --git a/examples/iMAML/README.md b/examples/iMAML/README.md
new file mode 100644
index 00000000..91f95f69
--- /dev/null
+++ b/examples/iMAML/README.md
@@ -0,0 +1,26 @@
+# Implicit MAML few-shot Omniglot classification example
+
+Code for implicit MAML few-shot Omniglot classification, following the paper [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630), implemented with TorchOpt. We use `torchopt.sgd` as the inner-loop optimizer.
+
+## Usage
+
+```bash
+### Run
+python3 imaml_omniglot.py --inner_steps 5
+```
+
+## Results
+
+The figure below illustrates the experimental results.
+
+<div align="center">
+  <img src="./imaml-accs.png" />
+</div>
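+
+For reference, the meta-gradient that iMAML computes implicitly (with $\hat{\mathcal{L}}$ the support-set loss, $\mathcal{L}$ the query-set loss, $\phi^\star$ the adapted parameters, and $\lambda$ the strength of the proximal regularizer, `--reg_params` here) is
+
+$$
+\nabla_\theta F(\theta) = \Big( I + \tfrac{1}{\lambda} \nabla^2_{\phi} \hat{\mathcal{L}}(\phi^\star) \Big)^{-1} \nabla_{\phi} \mathcal{L}(\phi^\star),
+$$
+
+which this example approximates with a few conjugate-gradient iterations via `torchopt.linear_solve.solve_normal_cg`.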
diff --git a/examples/iMAML/imaml-accs.png b/examples/iMAML/imaml-accs.png
new file mode 100644
index 00000000..1d5134a4
Binary files /dev/null and b/examples/iMAML/imaml-accs.png differ
diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py
new file mode 100644
index 00000000..7b165ac0
--- /dev/null
+++ b/examples/iMAML/imaml_omniglot.py
@@ -0,0 +1,341 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details)
+for few-shot Omniglot classification.
+
+[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019).
+    Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124).
+    https://arxiv.org/abs/1909.04630
+"""
+
+import argparse
+import time
+
+import functorch
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import torchopt
+from torchopt import pytree
+
+
+from support.omniglot_loaders import OmniglotNShot  # isort: skip
+
+
+mpl.use('Agg')
+plt.style.use('bmh')
+
+
+def main():
+    argparser = argparse.ArgumentParser()
+    argparser.add_argument('--n_way', type=int, help='n way', default=5)
+    argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
+    argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
+    argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5)
+    argparser.add_argument(
+        '--reg_params', type=float, help='regularization parameter of the proximal term', default=2.0
+    )
+    argparser.add_argument(
+        '--task_num', type=int, help='meta-batch size, i.e. the number of tasks per batch', default=16
+    )
+    argparser.add_argument('--seed', type=int, help='random seed', default=1)
+    args = argparser.parse_args()
+
+    torch.manual_seed(args.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+        torch.backends.cudnn.benchmark = False
+        torch.backends.cudnn.deterministic = True
+    np.random.seed(args.seed)
+    rng = np.random.default_rng(args.seed)
+
+    # Set up the Omniglot loader.
+    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    db = OmniglotNShot(
+        '/tmp/omniglot-data',
+        batchsz=args.task_num,
+        n_way=args.n_way,
+        k_shot=args.k_spt,
+        k_query=args.k_qry,
+        imgsz=28,
+        rng=rng,
+        device=device,
+    )
+
+    # Create a vanilla PyTorch neural network.
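+    # Note: `track_running_stats=False` makes BatchNorm normalize with per-batch
+    # statistics at both train and test time (no running mean/var is kept), the
+    # common choice in MAML-style few-shot examples where episodes are small.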
+    net = nn.Sequential(
+        nn.Conv2d(1, 64, 3),
+        nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+        nn.ReLU(inplace=False),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+        nn.ReLU(inplace=False),
+        nn.MaxPool2d(2, 2),
+        nn.Conv2d(64, 64, 3),
+        nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+        nn.ReLU(inplace=False),
+        nn.MaxPool2d(2, 2),
+        nn.Flatten(),
+        nn.Linear(64, args.n_way),
+    ).to(device)
+
+    # We will use Adam to (meta-)optimize the initial parameters
+    # to be adapted.
+    net.train()
+    fnet, params = functorch.make_functional(net)
+    meta_opt = torchopt.adam(lr=1e-3)
+    meta_opt_state = meta_opt.init(params)
+
+    log = []
+    test(db, [params, fnet], epoch=-1, log=log, args=args)
+    for epoch in range(10):
+        meta_opt, meta_opt_state = train(
+            db, [params, fnet], (meta_opt, meta_opt_state), epoch, log, args
+        )
+        test(db, [params, fnet], epoch, log, args)
+        plot(log)
+
+
+def train(db, net, meta_opt_and_state, epoch, log, args):
+    n_train_iter = db.x_train.shape[0] // db.batchsz
+    params, fnet = net
+    meta_opt, meta_opt_state = meta_opt_and_state
+    # `fnet` is a stateless, functional version of the network: it is called as
+    # `fnet(params, inputs)` with the parameters passed in explicitly.
+
+    for batch_idx in range(n_train_iter):
+        start_time = time.time()
+        # Sample a batch of support and query images and labels.
+        x_spt, y_spt, x_qry, y_qry = db.next()
+
+        task_num, setsz, c_, h, w = x_spt.size()
+        querysz = x_qry.size(1)
+
+        n_inner_iter = args.inner_steps
+        reg_param = args.reg_params
+        qry_losses = []
+        qry_accs = []
+
+        init_params_copy = pytree.tree_map(
+            lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+        )
+
+        for i in range(task_num):
+            # Optimize the likelihood of the support set by taking
+            # gradient steps w.r.t. the model's parameters.
+            # This adapts the model's meta-parameters to the task.
+            optimal_params = train_imaml_inner_solver(
+                init_params_copy,
+                params,
+                (x_spt[i], y_spt[i]),
+                (fnet, n_inner_iter, reg_param),
+            )
+
+            # The final set of adapted parameters will induce some
+            # final loss and accuracy on the query dataset.
+            # These will be used to update the model's meta-parameters.
+            qry_logits = fnet(optimal_params, x_qry[i])
+            qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+
+            # Update the model's meta-parameters to optimize the query
+            # losses across all of the tasks sampled in this batch.
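+            # Note: the meta-gradient w.r.t. `params` flows through `optimal_params`
+            # implicitly, via the `custom_root` decorator on the inner solver (implicit
+            # function theorem), not by backpropagating through the unrolled inner loop.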
+            # Scale the loss by `task_num` so the meta-updates accumulated over
+            # this batch average the per-task meta-gradients.
+            meta_grads = torch.autograd.grad(qry_loss / task_num, params)
+            meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
+            params = torchopt.apply_updates(params, meta_updates)
+            qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
+
+            qry_losses.append(qry_loss.detach())
+            qry_accs.append(qry_acc)
+
+        qry_losses = sum(qry_losses) / task_num
+        qry_accs = 100.0 * sum(qry_accs) / task_num
+        i = epoch + float(batch_idx) / n_train_iter
+        iter_time = time.time() - start_time
+
+        print(
+            f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+        )
+
+        log.append(
+            {
+                'epoch': i,
+                'loss': qry_losses,
+                'acc': qry_accs,
+                'mode': 'train',
+                'time': time.time(),
+            }
+        )
+
+    return (meta_opt, meta_opt_state)
+
+
+def test(db, net, epoch, log, args):
+    # Note: at test time we still adapt to each task's support set with the
+    # inner solver, but we do *not* further fine-tune the meta-parameters.
+    # Most research papers using MAML for this task do an extra stage of
+    # fine-tuning here that should be added if you are adapting this code
+    # for research.
+    params, fnet = net
+    n_test_iter = db.x_test.shape[0] // db.batchsz
+
+    qry_losses = []
+    qry_accs = []
+
+    # TODO: Maybe pull this out into a separate module so it
+    # doesn't have to be duplicated between `train` and `test`?
+    n_inner_iter = args.inner_steps
+    reg_param = args.reg_params
+    init_params_copy = pytree.tree_map(
+        lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+    )
+
+    for batch_idx in range(n_test_iter):
+        x_spt, y_spt, x_qry, y_qry = db.next('test')
+
+        task_num, setsz, c_, h, w = x_spt.size()
+        querysz = x_qry.size(1)
+
+        for i in range(task_num):
+            # Optimize the likelihood of the support set by taking
+            # gradient steps w.r.t. the model's parameters.
+            # This adapts the model's meta-parameters to the task.
+            optimal_params = test_imaml_inner_solver(
+                init_params_copy,
+                params,
+                (x_spt[i], y_spt[i]),
+                (fnet, n_inner_iter, reg_param),
+            )
+
+            # The query loss and acc induced by these parameters.
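+            # `reduction='none'` keeps per-sample query losses, so the averages
+            # below are exact over every query example across tasks and batches.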
+            qry_logits = fnet(optimal_params, x_qry[i])
+            qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+            qry_losses.append(qry_loss.detach())
+            qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
+
+    qry_losses = torch.cat(qry_losses).mean().item()
+    qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
+    print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
+    log.append(
+        {
+            'epoch': epoch + 1,
+            'loss': qry_losses,
+            'acc': qry_accs,
+            'mode': 'test',
+            'time': time.time(),
+        }
+    )
+
+
+def imaml_objective(optimal_params, init_params, data, aux):
+    x_spt, y_spt = data
+    fnet, n_inner_iter, reg_param = aux
+    fnet.eval()
+    y_pred = fnet(optimal_params, x_spt)
+    fnet.train()
+    regularization_loss = 0
+    for p1, p2 in zip(optimal_params, init_params):
+        regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+    loss = F.cross_entropy(y_pred, y_spt) + regularization_loss
+    return loss
+
+
+# Wrap the inner solver with `custom_root`: the optimality condition is that the
+# gradient of the regularized inner objective w.r.t. the adapted parameters
+# (argnums=0 of `imaml_objective`) vanishes at the solution, and we differentiate
+# the solution w.r.t. the meta-parameters (argnums=1 of the solver, i.e.
+# `init_params`) with a conjugate-gradient linear solve.
+@torchopt.diff.implicit.custom_root(
+    functorch.grad(imaml_objective, argnums=0),
+    argnums=1,
+    has_aux=False,
+    solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
+)
+def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
+    x_spt, y_spt = data
+    fnet, n_inner_iter, reg_param = aux
+    # Initialize the functional TorchOpt optimizer for the inner loop.
+    params = init_params_copy
+    inner_opt = torchopt.sgd(lr=1e-1)
+    inner_opt_state = inner_opt.init(params)
+    with torch.enable_grad():
+        # Temporarily enable gradient computation for conducting the optimization
+        for _ in range(n_inner_iter):
+            pred = fnet(params, x_spt)
+            loss = F.cross_entropy(pred, y_spt)  # compute loss
+            # Compute the proximal regularization loss
+            regularization_loss = 0
+            for p1, p2 in zip(params, init_params):
+                regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+            final_loss = loss + regularization_loss
+            grads = torch.autograd.grad(final_loss, params)  # compute gradients
+            updates, inner_opt_state = inner_opt.update(grads, inner_opt_state)  # get updates
+            params = torchopt.apply_updates(params, updates)
+    return params
+
+
+# Same inner loop as above, but without `custom_root`: at test time we only need
+# the adapted parameters, not gradients through them.
+def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
+    x_spt, y_spt = data
+    fnet, n_inner_iter, reg_param = aux
+    # Initialize the functional TorchOpt optimizer for the inner loop.
+    params = init_params_copy
+    inner_opt = torchopt.sgd(lr=1e-1)
+    inner_opt_state = inner_opt.init(params)
+    with torch.enable_grad():
+        # Temporarily enable gradient computation for conducting the optimization
+        for _ in range(n_inner_iter):
+            pred = fnet(params, x_spt)
+            loss = F.cross_entropy(pred, y_spt)  # compute loss
+            # Compute the proximal regularization loss
+            regularization_loss = 0
+            for p1, p2 in zip(params, init_params):
+                regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+            final_loss = loss + regularization_loss
+            grads = torch.autograd.grad(final_loss, params)  # compute gradients
+            updates, inner_opt_state = inner_opt.update(grads, inner_opt_state)  # get updates
+            params = torchopt.apply_updates(params, updates)
+    return params
+
+
+def plot(log):
+    # Generally you should pull your plotting code out of your training
+    # script but we are doing it here for brevity.
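+    # Each record in `log` is a dict with keys 'epoch', 'loss', 'acc', 'mode'
+    # and 'time'; a DataFrame makes it easy to split the train/test curves.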
+ df = pd.DataFrame(log) + + fig, ax = plt.subplots(figsize=(8, 4), dpi=250) + train_df = df[df['mode'] == 'train'] + test_df = df[df['mode'] == 'test'] + ax.plot(train_df['epoch'], train_df['acc'], label='Train') + ax.plot(test_df['epoch'], test_df['acc'], label='Test') + ax.set_xlabel('Epoch') + ax.set_ylabel('Accuracy') + ax.set_ylim(80, 100) + ax.set_title('iMAML Omniglot') + ax.legend(ncol=2, loc='lower right') + fig.tight_layout() + fname = 'imaml-accs.png' + print(f'--- Plotting accuracy to {fname}') + fig.savefig(fname) + plt.close(fig) + + +if __name__ == '__main__': + main() diff --git a/examples/iMAML/support/omniglot_loaders.py b/examples/iMAML/support/omniglot_loaders.py new file mode 100644 index 00000000..d857d386 --- /dev/null +++ b/examples/iMAML/support/omniglot_loaders.py @@ -0,0 +1,327 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation: +# https://github.com/dragen1860/MAML-Pytorch +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py +# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py +# ============================================================================== + +import errno +import os + +import numpy as np +import torch +import torch.utils.data as data +import torchvision.transforms as transforms +from PIL import Image + + +class Omniglot(data.Dataset): + """ + The items are ``(filename, category)``. The index of all the categories can be found in + :attr:`idx_classes`. + + Args: + root: the directory where the dataset will be stored + transform: how to transform the input + target_transform: how to transform the target + download: need to download the dataset + """ + + urls = [ + 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip', + 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip', + ] + raw_folder = 'raw' + processed_folder = 'processed' + training_file = 'training.pt' + test_file = 'test.pt' + + def __init__(self, root, transform=None, target_transform=None, download=False): + self.root = root + self.transform = transform + self.target_transform = target_transform + + if not self._check_exists(): + if download: + self.download() + else: + raise RuntimeError('Dataset not found. 
You can use download=True to download it')
+
+        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
+        self.idx_classes = index_classes(self.all_items)
+
+    def __getitem__(self, index):
+        filename = self.all_items[index][0]
+        img = str.join('/', [self.all_items[index][2], filename])
+
+        target = self.idx_classes[self.all_items[index][1]]
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.all_items)
+
+    def _check_exists(self):
+        return os.path.exists(
+            os.path.join(self.root, self.processed_folder, 'images_evaluation')
+        ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background'))
+
+    def download(self):
+        import zipfile
+
+        from six.moves import urllib
+
+        if self._check_exists():
+            return
+
+        # download files
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        for url in self.urls:
+            print('== Downloading ' + url)
+            data = urllib.request.urlopen(url)
+            filename = url.rpartition('/')[2]
+            file_path = os.path.join(self.root, self.raw_folder, filename)
+            with open(file_path, 'wb') as f:
+                f.write(data.read())
+            file_processed = os.path.join(self.root, self.processed_folder)
+            print('== Unzip from ' + file_path + ' to ' + file_processed)
+            zip_ref = zipfile.ZipFile(file_path, 'r')
+            zip_ref.extractall(file_processed)
+            zip_ref.close()
+        print('Download finished.')
+
+
+def find_classes(root_dir):
+    retour = []
+    for (root, dirs, files) in os.walk(root_dir):
+        for f in files:
+            if f.endswith('png'):
+                r = root.split('/')
+                lr = len(r)
+                retour.append((f, r[lr - 2] + '/' + r[lr - 1], root))
+    print('== Found %d items ' % len(retour))
+    return retour
+
+
+def index_classes(items):
+    idx = {}
+    for i in items:
+        if i[1] not in idx:
+            idx[i[1]] = len(idx)
+    print('== Found %d classes' % len(idx))
+    return idx
+
+
+class OmniglotNShot:
+    def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None):
+        """
+        Constructs an N-way, K-shot episodic loader over Omniglot.
+
+        :param root: root directory where the cached dataset is stored
+        :param batchsz: meta-batch size, i.e. the number of tasks per batch
+        :param n_way: number of classes per task
+        :param k_shot: number of support examples per class
+        :param k_query: number of query examples per class
+        :param imgsz: side length that images are resized to
+        """
+
+        self.resize = imgsz
+        self.rng = rng
+        self.device = device
+        if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
+            # if root/omniglot.npy does not exist, download and preprocess the dataset
+            self.x = Omniglot(
+                root,
+                download=True,
+                transform=transforms.Compose(
+                    [
+                        lambda x: Image.open(x).convert('L'),
+                        lambda x: x.resize((imgsz, imgsz)),
+                        lambda x: np.reshape(x, (imgsz, imgsz, 1)),
+                        lambda x: np.transpose(x, [2, 0, 1]),
+                        lambda x: x / 255.0,
+                    ]
+                ),
+            )
+
+            # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total}
+            temp = {}
+            for (img, label) in self.x:
+                if label in temp.keys():
+                    temp[label].append(img)
+                else:
+                    temp[label] = [img]
+
+            self.x = []
+            for (
+                label,
+                imgs,
+            ) in temp.items():  # the label itself is discarded; each label has 20 imgs
+                self.x.append(np.array(imgs))
+
+            # stack into one array; every Omniglot class has the same number (20) of imgs
+            self.x = np.array(self.x).astype(np.float64)  # [1623 classes, 20 imgs each]
+            print('data shape:', self.x.shape)  # [1623, 20, 1, imgsz, imgsz]
+            temp = []  # Free memory
+            # save the whole dataset into a single npy file.
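+            # Note: the cache file is keyed only by its path, so delete
+            # `<root>/omniglot.npy` if you change `imgsz`.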
+            np.save(os.path.join(root, 'omniglot.npy'), self.x)
+            print('write into omniglot.npy.')
+        else:
+            # if the npy cache exists, just load it.
+            self.x = np.load(os.path.join(root, 'omniglot.npy'))
+            print('load from omniglot.npy.')
+
+        # [1623, 20, 1, imgsz, imgsz]
+        # NOTE: we cannot shuffle here; the training and test sets must stay distinct!
+        self.x_train, self.x_test = self.x[:1200], self.x[1200:]
+
+        # self.normalization()
+
+        self.batchsz = batchsz
+        self.n_cls = self.x.shape[0]  # 1623
+        self.n_way = n_way  # n way
+        self.k_shot = k_shot  # k shot
+        self.k_query = k_query  # k query
+        assert (k_shot + k_query) <= 20
+
+        # save pointer of current read batch in total cache
+        self.indexes = {'train': 0, 'test': 0}
+        self.datasets = {
+            'train': self.x_train,
+            'test': self.x_test,
+        }  # original data cached
+        print('DB: train', self.x_train.shape, 'test', self.x_test.shape)
+
+        self.datasets_cache = {
+            'train': self.load_data_cache(self.datasets['train']),  # current epoch data cached
+            'test': self.load_data_cache(self.datasets['test']),
+        }
+
+    def normalization(self):
+        """
+        Normalizes our data to have a mean of 0 and std of 1.
+        """
+        self.mean = np.mean(self.x_train)
+        self.std = np.std(self.x_train)
+        self.max = np.max(self.x_train)
+        self.min = np.min(self.x_train)
+        self.x_train = (self.x_train - self.mean) / self.std
+        self.x_test = (self.x_test - self.mean) / self.std
+
+        self.mean = np.mean(self.x_train)
+        self.std = np.std(self.x_train)
+        self.max = np.max(self.x_train)
+        self.min = np.min(self.x_train)
+
+    def load_data_cache(self, data_pack):
+        """
+        Collects several batches of data for N-shot learning.
+        :param data_pack: [cls_num, 20, 1, imgsz, imgsz]
+        :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
+        """
+
+        # take 5 way 1 shot as example: 5 * 1
+        setsz = self.k_shot * self.n_way
+        querysz = self.k_query * self.n_way
+        data_cache = []
+
+        for sample in range(10):  # num of episodes
+
+            x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
+            for i in range(self.batchsz):  # one batch means one set
+
+                x_spt, y_spt, x_qry, y_qry = [], [], [], []
+                selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
+
+                for j, cur_class in enumerate(selected_cls):
+
+                    selected_img = self.rng.choice(20, self.k_shot + self.k_query, False)
+
+                    # meta-training and meta-test
+                    x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]])
+                    x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]])
+                    y_spt.append([j for _ in range(self.k_shot)])
+                    y_qry.append([j for _ in range(self.k_query)])
+
+                # shuffle inside a batch
+                perm = self.rng.permutation(self.n_way * self.k_shot)
+                x_spt = np.array(x_spt).reshape(
+                    self.n_way * self.k_shot, 1, self.resize, self.resize
+                )[perm]
+                y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
+                perm = self.rng.permutation(self.n_way * self.k_query)
+                x_qry = np.array(x_qry).reshape(
+                    self.n_way * self.k_query, 1, self.resize, self.resize
+                )[perm]
+                y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
+
+                # append [setsz, 1, imgsz, imgsz] => [b, setsz, 1, imgsz, imgsz]
+                x_spts.append(x_spt)
+                y_spts.append(y_spt)
+                x_qrys.append(x_qry)
+                y_qrys.append(y_qry)
+
+            # [b, setsz, 1, imgsz, imgsz]
+            x_spts = np.array(x_spts, dtype=np.float32).reshape(
+                self.batchsz, setsz, 1, self.resize, self.resize
+            )
+            y_spts = np.array(y_spts, dtype=np.int64).reshape(self.batchsz, setsz)
+            # [b, qrysz, 1, imgsz, imgsz]
+            x_qrys = np.array(x_qrys, dtype=np.float32).reshape(
+                self.batchsz, querysz, 1, self.resize, self.resize
+            )
+            y_qrys = np.array(y_qrys, dtype=np.int64).reshape(self.batchsz, querysz)
+
+            x_spts, y_spts, x_qrys, y_qrys = [
+                torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys]
+            ]
+
+            data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
+
+        return data_cache
+
+    def next(self, mode='train'):
+        """
+        Gets the next batch from the dataset split with the given name.
+        :param mode: the name of the split (one of "train" or "test")
+        :return: the next [support_set_x, support_set_y, query_x, query_y] batch
+        """
+
+        # refresh the cache once the index runs past the number of cached batches
+        if self.indexes[mode] >= len(self.datasets_cache[mode]):
+            self.indexes[mode] = 0
+            self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
+
+        next_batch = self.datasets_cache[mode][self.indexes[mode]]
+        self.indexes[mode] += 1
+
+        return next_batch