diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9e3934ec..e6c23138 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
+- Add implicit MAML Omniglot few-shot classification example by [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#48](https://github.com/metaopt/torchopt/pull/48).
- Add object-oriented modules support for implicit meta-gradient by [@XuehaiPan](https://github.com/XuehaiPan) in [#101](https://github.com/metaopt/torchopt/pull/101).
- Bump PyTorch version to 1.13.0 by [@XuehaiPan](https://github.com/XuehaiPan) in [#104](https://github.com/metaopt/torchopt/pull/104).
- Add zero-order gradient estimation by [@JieRen98](https://github.com/JieRen98) in [#93](https://github.com/metaopt/torchopt/pull/93).
diff --git a/examples/iMAML/README.md b/examples/iMAML/README.md
new file mode 100644
index 00000000..91f95f69
--- /dev/null
+++ b/examples/iMAML/README.md
@@ -0,0 +1,18 @@
+# Implicit MAML (iMAML) few-shot Omniglot classification examples
+
+Code for implicit MAML (iMAML) few-shot Omniglot classification, following the paper [Meta-Learning with Implicit Gradients](https://arxiv.org/abs/1909.04630), implemented with TorchOpt. We use `torchopt.sgd` as the inner-loop optimizer.
+
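+The inner loop uses TorchOpt's functional optimizer API. As a minimal sketch of that pattern (`params`, `loss_fn`, and `inner_steps` here are hypothetical placeholders, not names from the script):
+
+```python
+import torch
+import torchopt
+
+opt = torchopt.sgd(lr=1e-1)   # functional SGD, as in the inner loop of the example
+opt_state = opt.init(params)  # initialize optimizer state for the parameters
+for _ in range(inner_steps):
+    loss = loss_fn(params)                             # task loss on the support set
+    grads = torch.autograd.grad(loss, params)          # explicit gradients
+    updates, opt_state = opt.update(grads, opt_state)  # compute parameter updates
+    params = torchopt.apply_updates(params, updates)   # apply them
+```
+
+The full script wraps this loop in `torchopt.diff.implicit.custom_root`, so the meta-gradient is computed implicitly rather than by differentiating through the unrolled loop.
+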
+## Usage
+
+```bash
+### Run
+python3 imaml_omniglot.py --inner_steps 5
+```
+
+## Results
+
+The figure below illustrates the experimental results.
+
+![iMAML Omniglot accuracy](./imaml-accs.png)
diff --git a/examples/iMAML/imaml-accs.png b/examples/iMAML/imaml-accs.png
new file mode 100644
index 00000000..1d5134a4
Binary files /dev/null and b/examples/iMAML/imaml-accs.png differ
diff --git a/examples/iMAML/imaml_omniglot.py b/examples/iMAML/imaml_omniglot.py
new file mode 100644
index 00000000..7b165ac0
--- /dev/null
+++ b/examples/iMAML/imaml_omniglot.py
@@ -0,0 +1,341 @@
+# Copyright 2022 MetaOPT Team. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+This example shows how to use TorchOpt to do iMAML-GD (see [1] for more details)
+for few-shot Omniglot classification.
+
+[1] Rajeswaran, A., Finn, C., Kakade, S. M., & Levine, S. (2019).
+ Meta-learning with implicit gradients. In Advances in Neural Information Processing Systems (pp. 113-124).
+ https://arxiv.org/abs/1909.04630
+"""
+
+import argparse
+import time
+
+import functorch
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+import torchopt
+from torchopt import pytree
+
+
+from support.omniglot_loaders import OmniglotNShot # isort: skip
+
+
+mpl.use('Agg')
+plt.style.use('bmh')
+
+
+def main():
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument('--n_way', type=int, help='n way', default=5)
+ argparser.add_argument('--k_spt', type=int, help='k shot for support set', default=5)
+ argparser.add_argument('--k_qry', type=int, help='k shot for query set', default=5)
+ argparser.add_argument('--inner_steps', type=int, help='number of inner steps', default=5)
+ argparser.add_argument(
+ '--reg_params', type=float, help='regularization parameters', default=2.0
+ )
+ argparser.add_argument(
+ '--task_num', type=int, help='meta batch size, namely task num', default=16
+ )
+ argparser.add_argument('--seed', type=int, help='random seed', default=1)
+ args = argparser.parse_args()
+
+ torch.manual_seed(args.seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(args.seed)
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+ np.random.seed(args.seed)
+ rng = np.random.default_rng(args.seed)
+
+ # Set up the Omniglot loader.
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ db = OmniglotNShot(
+ '/tmp/omniglot-data',
+ batchsz=args.task_num,
+ n_way=args.n_way,
+ k_shot=args.k_spt,
+ k_query=args.k_qry,
+ imgsz=28,
+ rng=rng,
+ device=device,
+ )
+
+ # Create a vanilla PyTorch neural network.
+ net = nn.Sequential(
+ nn.Conv2d(1, 64, 3),
+ nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+ nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2),
+ nn.Conv2d(64, 64, 3),
+ nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+ nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2),
+ nn.Conv2d(64, 64, 3),
+ nn.BatchNorm2d(64, momentum=1.0, affine=True, track_running_stats=False),
+ nn.ReLU(inplace=False),
+ nn.MaxPool2d(2, 2),
+ nn.Flatten(),
+ nn.Linear(64, args.n_way),
+ ).to(device)
+
+ # We will use Adam to (meta-)optimize the initial parameters
+ # to be adapted.
+ net.train()
+ fnet, params = functorch.make_functional(net)
+ meta_opt = torchopt.adam(lr=1e-3)
+ meta_opt_state = meta_opt.init(params)
+
+ log = []
+ test(db, [params, fnet], epoch=-1, log=log, args=args)
+ for epoch in range(10):
+ meta_opt, meta_opt_state = train(
+ db, [params, fnet], (meta_opt, meta_opt_state), epoch, log, args
+ )
+ test(db, [params, fnet], epoch, log, args)
+ plot(log)
+
+
+def train(db, net, meta_opt_and_state, epoch, log, args):
+ n_train_iter = db.x_train.shape[0] // db.batchsz
+ params, fnet = net
+ meta_opt, meta_opt_state = meta_opt_and_state
+    # `net` is the `[params, fnet]` pair produced by `functorch.make_functional`
+    # in `main()`: `fnet` is a stateless functional module that is called as
+    # `fnet(params, inputs)`.
+
+ for batch_idx in range(n_train_iter):
+ start_time = time.time()
+ # Sample a batch of support and query images and labels.
+ x_spt, y_spt, x_qry, y_qry = db.next()
+
+ task_num, setsz, c_, h, w = x_spt.size()
+ querysz = x_qry.size(1)
+
+ n_inner_iter = args.inner_steps
+ reg_param = args.reg_params
+ qry_losses = []
+ qry_accs = []
+
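+        # Start the inner solve from a detached copy of the meta-parameters.
+        # The live `params` enter the inner problem only through the proximal
+        # regularization term, which is the path the implicit meta-gradient
+        # flows through.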
+ init_params_copy = pytree.tree_map(
+ lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+ )
+
+ for i in range(task_num):
+ # Optimize the likelihood of the support set by taking
+ # gradient steps w.r.t. the model's parameters.
+ # This adapts the model's meta-parameters to the task.
+
+ optimal_params = train_imaml_inner_solver(
+ init_params_copy,
+ params,
+ (x_spt[i], y_spt[i]),
+ (fnet, n_inner_iter, reg_param),
+ )
+ # The final set of adapted parameters will induce some
+ # final loss and accuracy on the query dataset.
+ # These will be used to update the model's meta-parameters.
+ qry_logits = fnet(optimal_params, x_qry[i])
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i])
+            # Update the model's meta-parameters to optimize the query
+            # losses across all of the tasks sampled in this batch.
+            # The loss is divided by `task_num` so the summed per-task
+            # meta-gradients average over the batch.
+            meta_grads = torch.autograd.grad(qry_loss / task_num, params)
+ meta_updates, meta_opt_state = meta_opt.update(meta_grads, meta_opt_state)
+ params = torchopt.apply_updates(params, meta_updates)
+ qry_acc = (qry_logits.argmax(dim=1) == y_qry[i]).sum().item() / querysz
+
+ qry_losses.append(qry_loss.detach())
+ qry_accs.append(qry_acc)
+
+ qry_losses = sum(qry_losses) / task_num
+ qry_accs = 100.0 * sum(qry_accs) / task_num
+ i = epoch + float(batch_idx) / n_train_iter
+ iter_time = time.time() - start_time
+
+ print(
+ f'[Epoch {i:.2f}] Train Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f} | Time: {iter_time:.2f}'
+ )
+
+ log.append(
+ {
+ 'epoch': i,
+ 'loss': qry_losses,
+ 'acc': qry_accs,
+ 'mode': 'train',
+ 'time': time.time(),
+ }
+ )
+
+ return (meta_opt, meta_opt_state)
+
+
+def test(db, net, epoch, log, args):
+ # Crucially in our testing procedure here, we do *not* fine-tune
+ # the model during testing for simplicity.
+ # Most research papers using MAML for this task do an extra
+ # stage of fine-tuning here that should be added if you are
+ # adapting this code for research.
+ params, fnet = net
+ n_test_iter = db.x_test.shape[0] // db.batchsz
+
+ qry_losses = []
+ qry_accs = []
+
+ # TODO: Maybe pull this out into a separate module so it
+ # doesn't have to be duplicated between `train` and `test`?
+ n_inner_iter = args.inner_steps
+ reg_param = args.reg_params
+ init_params_copy = pytree.tree_map(
+ lambda t: t.clone().detach_().requires_grad_(requires_grad=t.requires_grad), params
+ )
+
+ for batch_idx in range(n_test_iter):
+ x_spt, y_spt, x_qry, y_qry = db.next('test')
+
+ task_num, setsz, c_, h, w = x_spt.size()
+ querysz = x_qry.size(1)
+
+ for i in range(task_num):
+ # Optimize the likelihood of the support set by taking
+ # gradient steps w.r.t. the model's parameters.
+ # This adapts the model's meta-parameters to the task.
+
+ optimal_params = test_imaml_inner_solver(
+ init_params_copy,
+ params,
+ (x_spt[i], y_spt[i]),
+ (fnet, n_inner_iter, reg_param),
+ )
+
+ # The query loss and acc induced by these parameters.
+ qry_logits = fnet(optimal_params, x_qry[i])
+ qry_loss = F.cross_entropy(qry_logits, y_qry[i], reduction='none')
+ qry_losses.append(qry_loss.detach())
+ qry_accs.append((qry_logits.argmax(dim=1) == y_qry[i]).detach())
+
+ qry_losses = torch.cat(qry_losses).mean().item()
+ qry_accs = 100.0 * torch.cat(qry_accs).float().mean().item()
+ print(f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}')
+ log.append(
+ {
+ 'epoch': epoch + 1,
+ 'loss': qry_losses,
+ 'acc': qry_accs,
+ 'mode': 'test',
+ 'time': time.time(),
+ }
+ )
+
+
+def imaml_objective(optimal_params, init_params, data, aux):
+ x_spt, y_spt = data
+ fnet, n_inner_iter, reg_param = aux
+ fnet.eval()
+ y_pred = fnet(optimal_params, x_spt)
+ fnet.train()
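+    # Proximal regularization (reg_param / 2) * ||optimal - init||^2 ties the
+    # adapted parameters to the meta-parameters; it is what makes the inner
+    # solution depend differentiably on the meta-parameters.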
+ regularization_loss = 0
+ for p1, p2 in zip(optimal_params, init_params):
+ regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+ loss = F.cross_entropy(y_pred, y_spt) + regularization_loss
+ return loss
+
+
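+# `custom_root` wraps the solver below with implicit differentiation: its first
+# argument, `functorch.grad(imaml_objective, argnums=0)`, is the stationarity
+# condition the returned parameters satisfy; `argnums=1` marks the solver's
+# second argument (`init_params`, the meta-parameters) as the one to
+# differentiate with respect to; and `solve_normal_cg(maxiter=5)` solves the
+# implicit linear system by conjugate gradients on the normal equations.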
+@torchopt.diff.implicit.custom_root(
+ functorch.grad(imaml_objective, argnums=0),
+ argnums=1,
+ has_aux=False,
+ solve=torchopt.linear_solve.solve_normal_cg(maxiter=5, atol=0),
+)
+def train_imaml_inner_solver(init_params_copy, init_params, data, aux):
+ x_spt, y_spt = data
+ fnet, n_inner_iter, reg_param = aux
+ # Initial functional optimizer based on TorchOpt
+ params = init_params_copy
+ inner_opt = torchopt.sgd(lr=1e-1)
+ inner_opt_state = inner_opt.init(params)
+ with torch.enable_grad():
+ # Temporarily enable gradient computation for conducting the optimization
+ for _ in range(n_inner_iter):
+ pred = fnet(params, x_spt)
+ loss = F.cross_entropy(pred, y_spt) # compute loss
+ # Compute regularization loss
+ regularization_loss = 0
+ for p1, p2 in zip(params, init_params):
+ regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+ final_loss = loss + regularization_loss
+ grads = torch.autograd.grad(final_loss, params) # compute gradients
+ updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
+ params = torchopt.apply_updates(params, updates)
+ return params
+
+
+def test_imaml_inner_solver(init_params_copy, init_params, data, aux):
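+    # Same inner loop as `train_imaml_inner_solver`, but without the
+    # `custom_root` decorator: at test time we only need the adapted
+    # parameters, not meta-gradients through the solve.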
+ x_spt, y_spt = data
+ fnet, n_inner_iter, reg_param = aux
+ # Initial functional optimizer based on TorchOpt
+ params = init_params_copy
+ inner_opt = torchopt.sgd(lr=1e-1)
+ inner_opt_state = inner_opt.init(params)
+ with torch.enable_grad():
+ # Temporarily enable gradient computation for conducting the optimization
+ for _ in range(n_inner_iter):
+ pred = fnet(params, x_spt)
+ loss = F.cross_entropy(pred, y_spt) # compute loss
+ # Compute regularization loss
+ regularization_loss = 0
+ for p1, p2 in zip(params, init_params):
+ regularization_loss += 0.5 * reg_param * torch.sum(torch.square(p1 - p2))
+ final_loss = loss + regularization_loss
+ grads = torch.autograd.grad(final_loss, params) # compute gradients
+ updates, inner_opt_state = inner_opt.update(grads, inner_opt_state) # get updates
+ params = torchopt.apply_updates(params, updates)
+ return params
+
+
+def plot(log):
+ # Generally you should pull your plotting code out of your training
+ # script but we are doing it here for brevity.
+ df = pd.DataFrame(log)
+
+ fig, ax = plt.subplots(figsize=(8, 4), dpi=250)
+ train_df = df[df['mode'] == 'train']
+ test_df = df[df['mode'] == 'test']
+ ax.plot(train_df['epoch'], train_df['acc'], label='Train')
+ ax.plot(test_df['epoch'], test_df['acc'], label='Test')
+ ax.set_xlabel('Epoch')
+ ax.set_ylabel('Accuracy')
+ ax.set_ylim(80, 100)
+ ax.set_title('iMAML Omniglot')
+ ax.legend(ncol=2, loc='lower right')
+ fig.tight_layout()
+ fname = 'imaml-accs.png'
+ print(f'--- Plotting accuracy to {fname}')
+ fig.savefig(fname)
+ plt.close(fig)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/iMAML/support/omniglot_loaders.py b/examples/iMAML/support/omniglot_loaders.py
new file mode 100644
index 00000000..d857d386
--- /dev/null
+++ b/examples/iMAML/support/omniglot_loaders.py
@@ -0,0 +1,327 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# These Omniglot loaders are from Jackie Loong's PyTorch MAML implementation:
+# https://github.com/dragen1860/MAML-Pytorch
+# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglot.py
+# https://github.com/dragen1860/MAML-Pytorch/blob/master/omniglotNShot.py
+# ==============================================================================
+
+import errno
+import os
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torchvision.transforms as transforms
+from PIL import Image
+
+
+class Omniglot(data.Dataset):
+ """
+ The items are ``(filename, category)``. The index of all the categories can be found in
+ :attr:`idx_classes`.
+
+ Args:
+ root: the directory where the dataset will be stored
+ transform: how to transform the input
+ target_transform: how to transform the target
+ download: need to download the dataset
+ """
+
+ urls = [
+ 'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
+ 'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip',
+ ]
+ raw_folder = 'raw'
+ processed_folder = 'processed'
+ training_file = 'training.pt'
+ test_file = 'test.pt'
+
+ def __init__(self, root, transform=None, target_transform=None, download=False):
+ self.root = root
+ self.transform = transform
+ self.target_transform = target_transform
+
+ if not self._check_exists():
+ if download:
+ self.download()
+ else:
+ raise RuntimeError('Dataset not found. You can use download=True to download it')
+
+ self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
+ self.idx_classes = index_classes(self.all_items)
+
+ def __getitem__(self, index):
+ filename = self.all_items[index][0]
+        img = os.path.join(self.all_items[index][2], filename)
+
+ target = self.idx_classes[self.all_items[index][1]]
+ if self.transform is not None:
+ img = self.transform(img)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ return img, target
+
+ def __len__(self):
+ return len(self.all_items)
+
+ def _check_exists(self):
+ return os.path.exists(
+ os.path.join(self.root, self.processed_folder, 'images_evaluation')
+ ) and os.path.exists(os.path.join(self.root, self.processed_folder, 'images_background'))
+
+ def download(self):
+ import zipfile
+
+ from six.moves import urllib
+
+ if self._check_exists():
+ return
+
+ # download files
+ try:
+ os.makedirs(os.path.join(self.root, self.raw_folder))
+ os.makedirs(os.path.join(self.root, self.processed_folder))
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ pass
+ else:
+ raise
+
+ for url in self.urls:
+ print('== Downloading ' + url)
+ data = urllib.request.urlopen(url)
+ filename = url.rpartition('/')[2]
+ file_path = os.path.join(self.root, self.raw_folder, filename)
+ with open(file_path, 'wb') as f:
+ f.write(data.read())
+ file_processed = os.path.join(self.root, self.processed_folder)
+ print('== Unzip from ' + file_path + ' to ' + file_processed)
+ zip_ref = zipfile.ZipFile(file_path, 'r')
+ zip_ref.extractall(file_processed)
+ zip_ref.close()
+ print('Download finished.')
+
+
+def find_classes(root_dir):
+ retour = []
+ for (root, dirs, files) in os.walk(root_dir):
+ for f in files:
+ if f.endswith('png'):
+ r = root.split('/')
+ lr = len(r)
+ retour.append((f, r[lr - 2] + '/' + r[lr - 1], root))
+ print('== Found %d items ' % len(retour))
+ return retour
+
+
+def index_classes(items):
+ idx = {}
+ for i in items:
+ if i[1] not in idx:
+ idx[i[1]] = len(idx)
+ print('== Found %d classes' % len(idx))
+ return idx
+
+
+class OmniglotNShot:
+ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, rng, device=None):
+ """
+ Different from mnistNShot, the
+ :param root:
+ :param batchsz: task num
+ :param n_way:
+ :param k_shot:
+ :param k_qry:
+ :param imgsz:
+ """
+
+ self.resize = imgsz
+ self.rng = rng
+ self.device = device
+ if not os.path.isfile(os.path.join(root, 'omniglot.npy')):
+            # if root/omniglot.npy does not exist, download and preprocess the dataset
+ self.x = Omniglot(
+ root,
+ download=True,
+ transform=transforms.Compose(
+ [
+ lambda x: Image.open(x).convert('L'),
+ lambda x: x.resize((imgsz, imgsz)),
+ lambda x: np.reshape(x, (imgsz, imgsz, 1)),
+ lambda x: np.transpose(x, [2, 0, 1]),
+ lambda x: x / 255.0,
+ ]
+ ),
+ )
+
+ # {label: [img1, img2..., img20], label2: [img1, img2, ...], ... 1623 labels in total}
+ temp = {}
+ for (img, label) in self.x:
+ if label in temp.keys():
+ temp[label].append(img)
+ else:
+ temp[label] = [img]
+
+ self.x = []
+            for label, imgs in temp.items():
+                # label info is discarded; each label contains 20 images
+                self.x.append(np.array(imgs))
+
+            # each character class contains exactly 20 images, so the grouped
+            # list stacks into a regular array of shape [1623, 20, 1, imgsz, imgsz]
+            self.x = np.array(self.x).astype(np.float64)
+            print('data shape:', self.x.shape)
+ temp = [] # Free memory
+ # save all dataset into npy file.
+ np.save(os.path.join(root, 'omniglot.npy'), self.x)
+ print('write into omniglot.npy.')
+ else:
+            # if omniglot.npy exists, just load it.
+ self.x = np.load(os.path.join(root, 'omniglot.npy'))
+ print('load from omniglot.npy.')
+
+        # [1623, 20, 1, imgsz, imgsz]
+        # TODO: cannot shuffle here, we must keep the training and test sets distinct!
+ self.x_train, self.x_test = self.x[:1200], self.x[1200:]
+
+ # self.normalization()
+
+ self.batchsz = batchsz
+ self.n_cls = self.x.shape[0] # 1623
+ self.n_way = n_way # n way
+ self.k_shot = k_shot # k shot
+ self.k_query = k_query # k query
+ assert (k_shot + k_query) <= 20
+
+ # save pointer of current read batch in total cache
+ self.indexes = {'train': 0, 'test': 0}
+ self.datasets = {
+ 'train': self.x_train,
+ 'test': self.x_test,
+ } # original data cached
+ print('DB: train', self.x_train.shape, 'test', self.x_test.shape)
+
+ self.datasets_cache = {
+ 'train': self.load_data_cache(self.datasets['train']), # current epoch data cached
+ 'test': self.load_data_cache(self.datasets['test']),
+ }
+
+ def normalization(self):
+ """
+        Normalizes the data to have a mean of 0 and a std of 1.
+ """
+ self.mean = np.mean(self.x_train)
+ self.std = np.std(self.x_train)
+ self.max = np.max(self.x_train)
+ self.min = np.min(self.x_train)
+ # print("before norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
+ self.x_train = (self.x_train - self.mean) / self.std
+ self.x_test = (self.x_test - self.mean) / self.std
+
+ self.mean = np.mean(self.x_train)
+ self.std = np.std(self.x_train)
+ self.max = np.max(self.x_train)
+ self.min = np.min(self.x_train)
+
+ # print("after norm:", "mean", self.mean, "max", self.max, "min", self.min, "std", self.std)
+
+ def load_data_cache(self, data_pack):
+ """
+ Collects several batches data for N-shot learning
+ :param data_pack: [cls_num, 20, 84, 84, 1]
+ :return: A list with [support_set_x, support_set_y, target_x, target_y] ready to be fed to our networks
+ """
+
+ # take 5 way 1 shot as example: 5 * 1
+ setsz = self.k_shot * self.n_way
+ querysz = self.k_query * self.n_way
+ data_cache = []
+
+        for _ in range(10):  # number of meta-batches to pre-generate per cache refill
+
+ x_spts, y_spts, x_qrys, y_qrys = [], [], [], []
+ for i in range(self.batchsz): # one batch means one set
+
+ x_spt, y_spt, x_qry, y_qry = [], [], [], []
+ selected_cls = self.rng.choice(data_pack.shape[0], self.n_way, False)
+
+ for j, cur_class in enumerate(selected_cls):
+
+ selected_img = self.rng.choice(20, self.k_shot + self.k_query, False)
+
+ # meta-training and meta-test
+ x_spt.append(data_pack[cur_class][selected_img[: self.k_shot]])
+ x_qry.append(data_pack[cur_class][selected_img[self.k_shot :]])
+ y_spt.append([j for _ in range(self.k_shot)])
+ y_qry.append([j for _ in range(self.k_query)])
+
+ # shuffle inside a batch
+ perm = self.rng.permutation(self.n_way * self.k_shot)
+ x_spt = np.array(x_spt).reshape(
+ self.n_way * self.k_shot, 1, self.resize, self.resize
+ )[perm]
+ y_spt = np.array(y_spt).reshape(self.n_way * self.k_shot)[perm]
+ perm = self.rng.permutation(self.n_way * self.k_query)
+ x_qry = np.array(x_qry).reshape(
+ self.n_way * self.k_query, 1, self.resize, self.resize
+ )[perm]
+ y_qry = np.array(y_qry).reshape(self.n_way * self.k_query)[perm]
+
+ # append [sptsz, 1, 84, 84] => [b, setsz, 1, 84, 84]
+ x_spts.append(x_spt)
+ y_spts.append(y_spt)
+ x_qrys.append(x_qry)
+ y_qrys.append(y_qry)
+
+ # [b, setsz, 1, 84, 84]
+ x_spts = np.array(x_spts, dtype=np.float32).reshape(
+ self.batchsz, setsz, 1, self.resize, self.resize
+ )
+            y_spts = np.array(y_spts, dtype=np.int64).reshape(self.batchsz, setsz)
+ # [b, qrysz, 1, 84, 84]
+ x_qrys = np.array(x_qrys, dtype=np.float32).reshape(
+ self.batchsz, querysz, 1, self.resize, self.resize
+ )
+            y_qrys = np.array(y_qrys, dtype=np.int64).reshape(self.batchsz, querysz)
+
+ x_spts, y_spts, x_qrys, y_qrys = [
+ torch.from_numpy(z).to(self.device) for z in [x_spts, y_spts, x_qrys, y_qrys]
+ ]
+
+ data_cache.append([x_spts, y_spts, x_qrys, y_qrys])
+
+ return data_cache
+
+ def next(self, mode='train'):
+ """
+ Gets next batch from the dataset with name.
+ :param mode: The name of the splitting (one of "train", "val", "test")
+ :return:
+ """
+
+        # refresh the cache when the pointer runs past the number of cached batches
+ if self.indexes[mode] >= len(self.datasets_cache[mode]):
+ self.indexes[mode] = 0
+ self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
+
+ next_batch = self.datasets_cache[mode][self.indexes[mode]]
+ self.indexes[mode] += 1
+
+ return next_batch