Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ADD new test #908

Merged
merged 6 commits into from
Feb 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 64 additions & 101 deletions avalanche/benchmarks/utils/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,14 @@ def __len__(self):
class ReplayDataLoader:
"""Custom data loader for rehearsal/replay strategies."""

def __init__(
self,
data: AvalancheDataset,
memory: AvalancheDataset = None,
oversample_small_tasks: bool = False,
collate_mbatches=_default_collate_mbatches_fn,
batch_size: int = 32,
force_data_batch_size: int = None,
**kwargs
):
"""Custom data loader for rehearsal strategies.
def __init__(self, data: AvalancheDataset, memory: AvalancheDataset = None,
oversample_small_tasks: bool = False,
collate_mbatches=_default_collate_mbatches_fn,
batch_size: int = 32,
batch_size_mem: int = 32,
task_balanced_dataloader: bool = False,
**kwargs):
""" Custom data loader for rehearsal strategies.

This loader iterates in parallel over two datasets, the current `data` and the
rehearsal `memory`, which are used to create mini-batches by
Expand All @@ -272,11 +269,14 @@ def __init__(
:param collate_mbatches: function that given a sequence of mini-batches
(one for each task) combines them into a single mini-batch. Used to
combine the mini-batches obtained separately from each task.
:param batch_size: the size of the batch. It must be greater than or
equal to the number of tasks.
:param force_data_batch_size: How many of the samples should be from the
current `data`. If None, it will equally divide each batch between
samples from all seen tasks in the current `data` and `memory`.
:param batch_size: the size of the data batch. It must be greater
than or equal to the number of tasks.
:param batch_size_mem: the size of the memory batch. If
`task_balanced_dataloader` is set to True, it must be greater than
or equal to the number of tasks.
:param task_balanced_dataloader: if true, buffer data loaders will be
task-balanced, otherwise it creates a single data loader for the
buffer samples.
:param kwargs: data loader arguments used to instantiate the loader for
each task separately. See pytorch :class:`DataLoader`.
"""
Expand All @@ -288,55 +288,32 @@ def __init__(
self.oversample_small_tasks = oversample_small_tasks
self.collate_mbatches = collate_mbatches

if force_data_batch_size is not None:
assert (
force_data_batch_size <= batch_size
), "Forced batch size of data must be <= entire batch size"

remaining_example_data = 0

mem_keys = len(self.memory.task_set)
mem_batch_size = batch_size - force_data_batch_size
mem_batch_size_k = mem_batch_size // mem_keys
remaining_example_mem = mem_batch_size % mem_keys
num_keys = len(self.memory.task_set)
if task_balanced_dataloader:
assert batch_size_mem >= num_keys, \
"Batch size must be greator or equal " \
"to the number of tasks in the memory " \
"and current data."

assert mem_batch_size >= mem_keys, (
"Batch size must be greator or equal "
"to the number of tasks in the memory."
)
# Create dataloader for data items
self.loader_data, _ = self._create_dataloaders(
data, batch_size, 0, False, **kwargs)

self.loader_data, _ = self._create_dataloaders(
data, force_data_batch_size, remaining_example_data, **kwargs
)
self.loader_memory, _ = self._create_dataloaders(
memory, mem_batch_size_k, remaining_example_mem, **kwargs
)
# Create dataloader for memory items
if task_balanced_dataloader:
single_group_batch_size = batch_size_mem // num_keys
remaining_example = batch_size_mem % num_keys
else:
num_keys = len(self.data.task_set) + len(self.memory.task_set)
assert batch_size >= num_keys, (
"Batch size must be greator or equal "
"to the number of tasks in the memory "
"and current data."
)
single_group_batch_size = batch_size_mem
remaining_example = 0

single_group_batch_size = batch_size // num_keys
remaining_example = batch_size % num_keys
self.loader_memory, remaining_example = self._create_dataloaders(
memory, single_group_batch_size, remaining_example,
task_balanced_dataloader, **kwargs)

self.loader_data, remaining_example = self._create_dataloaders(
data, single_group_batch_size, remaining_example, **kwargs
)
self.loader_memory, remaining_example = self._create_dataloaders(
memory, single_group_batch_size, remaining_example, **kwargs
)

self.max_len = max(
[
len(d)
for d in chain(
self.loader_data.values(), self.loader_memory.values()
)
]
)
self.max_len = max([len(d) for d in chain(
self.loader_data.values(), self.loader_memory.values())]
)

def __iter__(self):
iter_data_dataloaders = {}
Expand All @@ -347,33 +324,20 @@ def __iter__(self):
for t in self.loader_memory.keys():
iter_buffer_dataloaders[t] = iter(self.loader_memory[t])

max_len = max(
[
len(d)
for d in chain(
iter_data_dataloaders.values(),
iter_buffer_dataloaders.values(),
)
]
)
max_len = max([len(d) for d in iter_data_dataloaders.values()])

try:
for it in range(max_len):
mb_curr = []
self._get_mini_batch_from_data_dict(
self.data,
iter_data_dataloaders,
self.loader_data,
self.oversample_small_tasks,
mb_curr,
)
self.data, iter_data_dataloaders,
self.loader_data, False,
mb_curr)

self._get_mini_batch_from_data_dict(
self.memory,
iter_buffer_dataloaders,
self.loader_memory,
self.oversample_small_tasks,
mb_curr,
)
self.memory, iter_buffer_dataloaders,
self.loader_memory, self.oversample_small_tasks,
mb_curr)

yield self.collate_mbatches(mb_curr)
except StopIteration:
Expand All @@ -382,14 +346,9 @@ def __iter__(self):
def __len__(self):
    """Return the number of iterations of this loader.

    This is the precomputed ``self.max_len``, i.e. the length of the
    longest per-task dataloader among data and memory loaders.
    """
    return self.max_len

def _get_mini_batch_from_data_dict(
self,
data,
iter_dataloaders,
loaders_dict,
oversample_small_tasks,
mb_curr,
):
def _get_mini_batch_from_data_dict(self, data, iter_dataloaders,
loaders_dict, oversample_small_tasks,
mb_curr):
# list() is necessary because we may remove keys from the
# dictionary. This would break the generator.
for t in list(iter_dataloaders.keys()):
Expand All @@ -408,19 +367,23 @@ def _get_mini_batch_from_data_dict(
continue
mb_curr.append(tbatch)

def _create_dataloaders(
self, data_dict, single_exp_batch_size, remaining_example, **kwargs
):
def _create_dataloaders(self, data_dict, single_exp_batch_size,
remaining_example, task_balanced_dataloader,
**kwargs):
loaders_dict: Dict[int, DataLoader] = {}
for task_id in data_dict.task_set:
data = data_dict.task_set[task_id]
current_batch_size = single_exp_batch_size
if remaining_example > 0:
current_batch_size += 1
remaining_example -= 1
loaders_dict[task_id] = DataLoader(
data, batch_size=current_batch_size, **kwargs
)
if task_balanced_dataloader:
for task_id in data_dict.task_set:
data = data_dict.task_set[task_id]
current_batch_size = single_exp_batch_size
if remaining_example > 0:
current_batch_size += 1
remaining_example -= 1
loaders_dict[task_id] = DataLoader(
data, batch_size=current_batch_size, **kwargs)
else:
loaders_dict[0] = DataLoader(
data_dict, batch_size=single_exp_batch_size, **kwargs)

return loaders_dict, remaining_example


Expand Down
43 changes: 36 additions & 7 deletions avalanche/models/dynamic_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,36 @@ class MultiTaskModule(DynamicModule):
def __init__(self):
super().__init__()
self.known_train_tasks_labels = set()
self.max_class_label = 0
""" Set of task labels encountered up to now. """

def adaptation(self, dataset: AvalancheDataset = None):
    """Adapt the module (freeze units, add units...) using the current
    data. Optimizers must be updated after the model adaptation.

    Avalanche strategies call this method to adapt the architecture
    *before* processing each experience. Strategies also update the
    optimizer automatically.

    .. warning::
        As a general rule, you should NOT use this method to train the
        model. The dataset should be used only to check conditions which
        require the model's adaptation, such as the discovery of new
        classes or tasks.

    :param dataset: data from the current experience.
    :return:
    """
    # Track the largest class label seen so far (+1 to convert a max
    # label into an output-layer size).
    # NOTE(review): the `dataset=None` default will raise AttributeError
    # here on `dataset.targets` — confirm all callers pass a dataset.
    self.max_class_label = max(self.max_class_label,
                               max(dataset.targets) + 1)
    # Dispatch on the nn.Module training flag: architecture changes are
    # only meant to happen during training-time adaptation.
    if self.training:
        self.train_adaptation(dataset)
    else:
        self.eval_adaptation(dataset)

def eval_adaptation(self, dataset: AvalancheDataset):
    """Hook invoked by :meth:`adaptation` when the module is in eval mode.

    Default is a no-op; subclasses may override it to adapt the module
    at evaluation time.

    :param dataset: data from the current experience.
    """
    pass

def train_adaptation(self, dataset: AvalancheDataset = None):
"""Update known task labels."""
task_labels = dataset.targets_task_labels
Expand Down Expand Up @@ -127,17 +155,16 @@ def forward(
else:
unique_tasks = torch.unique(task_labels)

out = None
out = torch.zeros(x.shape[0], self.max_class_label, device=x.device)
for task in unique_tasks:
task_mask = task_labels == task
x_task = x[task_mask]
out_task = self.forward_single_task(x_task, task.item())

if out is None:
out = torch.empty(
x.shape[0], *out_task.shape[1:], device=out_task.device
)
out[task_mask] = out_task
assert len(out_task.shape) == 2,\
"multi-head assumes mini-batches of 2 dimensions " \
"<batch, classes>"
n_labels_head = out_task.shape[1]
out[task_mask, :n_labels_head] = out_task
return out

def forward_single_task(
Expand Down Expand Up @@ -254,6 +281,8 @@ def __init__(self, in_features, initial_out_features=2):
self.in_features, self.starting_out_features
)
self.classifiers["0"] = first_head
self.max_class_label = max(self.max_class_label,
initial_out_features)

def adaptation(self, dataset: AvalancheDataset):
"""If `dataset` contains new tasks, a new head is initialized.
Expand Down
2 changes: 2 additions & 0 deletions avalanche/models/helper_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(self, model: nn.Module, classifier_name: str):
):
param.data = param_old

self.max_class_label = max(self.max_class_label,
out_size)
self._initialized = True

def forward_single_task(self, x: torch.Tensor, task_label: int):
Expand Down
59 changes: 36 additions & 23 deletions avalanche/training/plugins/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,46 @@ class ReplayPlugin(SupervisedPlugin):

The :mem_size: attribute controls the total number of patterns to be stored
in the external memory.

:param batch_size: the size of the data batch. If set to `None`, it
will be set equal to the strategy's batch size.
:param batch_size_mem: the size of the memory batch. If
`task_balanced_dataloader` is set to True, it must be greater than or
equal to the number of tasks. If its value is set to `None`
(the default value), it will be automatically set equal to the
data batch size.
:param task_balanced_dataloader: if True, buffer data loaders will be
task-balanced, otherwise it will create a single dataloader for the
buffer samples.
:param storage_policy: The policy that controls how to add new exemplars
in memory
:param force_data_batch_size: How many of the samples should be from the
current `data`. If None, it will equally divide each batch between
samples from all seen tasks in the current `data` and `memory`.
"""

def __init__(
self,
mem_size: int = 200,
storage_policy: Optional["ExemplarsBuffer"] = None,
force_data_batch_size: int = None,
):
def __init__(self, mem_size: int = 200, batch_size: int = None,
batch_size_mem: int = None,
task_balanced_dataloader: bool = False,
storage_policy: Optional["ExemplarsBuffer"] = None):
super().__init__()
self.mem_size = mem_size
self.force_data_batch_size = force_data_batch_size
self.batch_size = batch_size
self.batch_size_mem = batch_size_mem
self.task_balanced_dataloader = task_balanced_dataloader

if storage_policy is not None: # Use other storage policy
self.storage_policy = storage_policy
assert storage_policy.max_size == self.mem_size
else: # Default
self.storage_policy = ExperienceBalancedBuffer(
max_size=self.mem_size, adaptive_size=True
)
max_size=self.mem_size,
adaptive_size=True)

@property
def ext_mem(self):
return self.storage_policy.buffer_groups # a Dict<task_id, Dataset>

def before_training_exp(
self,
strategy: "SupervisedTemplate",
num_workers: int = 0,
shuffle: bool = True,
**kwargs
):
def before_training_exp(self, strategy: "SupervisedTemplate",
num_workers: int = 0, shuffle: bool = True,
**kwargs):
"""
Dataloader to build batches containing examples from both memories and
the training dataset
Expand All @@ -73,15 +77,24 @@ def before_training_exp(
# first experience. We don't use the buffer, no need to change
# the dataloader.
return

batch_size = self.batch_size
if batch_size is None:
batch_size = strategy.train_mb_size

batch_size_mem = self.batch_size_mem
if batch_size_mem is None:
batch_size_mem = strategy.train_mb_size

strategy.dataloader = ReplayDataLoader(
strategy.adapted_dataset,
self.storage_policy.buffer,
oversample_small_tasks=True,
batch_size=batch_size,
batch_size_mem=batch_size_mem,
task_balanced_dataloader=self.task_balanced_dataloader,
num_workers=num_workers,
batch_size=strategy.train_mb_size,
force_data_batch_size=self.force_data_batch_size,
shuffle=shuffle,
)
shuffle=shuffle)

def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
self.storage_policy.update(strategy, **kwargs)
Loading