Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add Zero-shot Knowledge Transfer via Adversarial Belief Matching #241

Merged
merged 7 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions configs/distill/mmcls/dafl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ Learning portable neural networks is very essential for computer vision for the

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :----------------------------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone (pretrain) & logits (train) | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.11 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |

## Citation

Expand Down
41 changes: 41 additions & 0 deletions configs/distill/mmcls/zskt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Zero-shot Knowledge Transfer via Adversarial Belief Matching (ZSKT)

> [Zero-shot Knowledge Transfer via Adversarial Belief Matching](https://arxiv.org/abs/1905.09768)

<!-- [ALGORITHM] -->

## Abstract

Performing knowledge transfer from a large teacher network to a smaller student is a popular task in modern deep learning applications. However, due to growing dataset sizes and stricter privacy regulations, it is increasingly common not to have access to the data that was used to train the teacher. We propose a novel method which trains a student to match the predictions of its teacher without using any data or metadata. We achieve this by training an adversarial generator to search for images on which the student poorly matches the teacher, and then using them to train the student. Our resulting student closely approximates its teacher for simple datasets like SVHN, and on CIFAR10 we improve on the state-of-the-art for few-shot distillation (with 100 images per class), despite using no data. Finally, we also propose a metric to quantify the degree of belief matching between teacher and student in the vicinity of decision boundaries, and observe a significantly higher match between our zero-shot student and the teacher, than between a student distilled with real data and the teacher. Code available at: https://github.com/polo5/ZeroShotKnowledgeTransfer

## The teacher and student decision boundaries

![ZSKT_Distribution](/docs/en/imgs/model_zoo/zskt/zskt_distribution.png)

## Pseudo images sampled from the generator

![ZSKT_Fakeimgs](/docs/en/imgs/model_zoo/zskt/zskt_synthesis.png)

## Results and models

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :---------------: | :-----: | :-------------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-----------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------- |
| backbone & logits | Cifar10 | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb16_cifar10.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb16_cifar10.py) | 93.50 | 95.34 | 94.82 | [config](./dafl_logits_r34_r18_8xb256_cifar10.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_b16x8_cifar10_20210528-a8aa36a6.pth) \|[model](<>) \| [log](<>) |

## Citation

```latex
@article{micaelli2019zero,
title={Zero-shot knowledge transfer via adversarial belief matching},
author={Micaelli, Paul and Storkey, Amos J},
journal={Advances in Neural Information Processing Systems},
volume={32},
year={2019}
}
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add Acknowledgement: appreciate Davidgzx's contribution

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done


## Acknowledgement

Appreciate Davidgzx's contribution.
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
_base_ = [
'mmcls::_base_/datasets/cifar10_bs16.py',
'mmcls::_base_/schedules/cifar10_bs128.py',
'mmcls::_base_/default_runtime.py'
]

model = dict(
_scope_='mmrazor',
type='DataFreeDistillation',
data_preprocessor=dict(
type='ImgDataPreprocessor',
# RGB format normalization parameters
mean=[125.307, 122.961, 113.8575],
std=[51.5865, 50.847, 51.255],
# convert image from BGR to RGB
bgr_to_rgb=False),
architecture=dict(
cfg_path='mmcls::resnet/resnet18_8xb16_cifar10.py', pretrained=False),
teachers=dict(
r34=dict(
build_cfg=dict(
cfg_path='mmcls::resnet/resnet34_8xb16_cifar10.py',
pretrained=True),
ckpt_path='resnet34_b16x8_cifar10_20210528-a8aa36a6.pth')),
generator=dict(
type='ZSKTGenerator', img_size=32, latent_dim=256,
hidden_channels=128),
distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
bb_s1=dict(type='ModuleOutputs', source='backbone.layer1.1.relu'),
bb_s2=dict(type='ModuleOutputs', source='backbone.layer2.1.relu'),
bb_s3=dict(type='ModuleOutputs', source='backbone.layer3.1.relu'),
bb_s4=dict(type='ModuleOutputs', source='backbone.layer4.1.relu'),
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
r34_bb_s1=dict(
type='ModuleOutputs', source='r34.backbone.layer1.2.relu'),
r34_bb_s2=dict(
type='ModuleOutputs', source='r34.backbone.layer2.3.relu'),
r34_bb_s3=dict(
type='ModuleOutputs', source='r34.backbone.layer3.5.relu'),
r34_bb_s4=dict(
type='ModuleOutputs', source='r34.backbone.layer4.2.relu'),
r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')),
distill_losses=dict(
loss_s1=dict(type='ATLoss', loss_weight=250.0),
loss_s2=dict(type='ATLoss', loss_weight=250.0),
loss_s3=dict(type='ATLoss', loss_weight=250.0),
loss_s4=dict(type='ATLoss', loss_weight=250.0),
loss_kl=dict(
type='KLDivergence', loss_weight=2.0, reduction='mean')),
loss_forward_mappings=dict(
loss_s1=dict(
s_feature=dict(
from_student=True, recorder='bb_s1', record_idx=1),
t_feature=dict(
from_student=False, recorder='r34_bb_s1', record_idx=1)),
loss_s2=dict(
s_feature=dict(
from_student=True, recorder='bb_s2', record_idx=1),
t_feature=dict(
from_student=False, recorder='r34_bb_s2', record_idx=1)),
loss_s3=dict(
s_feature=dict(
from_student=True, recorder='bb_s3', record_idx=1),
t_feature=dict(
from_student=False, recorder='r34_bb_s3', record_idx=1)),
loss_s4=dict(
s_feature=dict(
from_student=True, recorder='bb_s4', record_idx=1),
t_feature=dict(
from_student=False, recorder='r34_bb_s4', record_idx=1)),
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='r34_fc')))),
generator_distiller=dict(
type='ConfigurableDistiller',
student_recorders=dict(
fc=dict(type='ModuleOutputs', source='head.fc')),
teacher_recorders=dict(
r34_fc=dict(type='ModuleOutputs', source='r34.head.fc')),
distill_losses=dict(
loss_kl=dict(
type='KLDivergence',
loss_weight=-2.0,
reduction='mean',
teacher_detach=False)),
loss_forward_mappings=dict(
loss_kl=dict(
preds_S=dict(from_student=True, recorder='fc'),
preds_T=dict(from_student=False, recorder='r34_fc')))),
student_iter=10)

# model wrapper
model_wrapper_cfg = dict(
type='mmengine.MMSeparateDistributedDataParallel',
broadcast_buffers=False,
find_unused_parameters=True)

# optimizer wrapper
optim_wrapper = dict(
_delete_=True,
constructor='mmrazor.SeparateOptimWrapperConstructor',
architecture=dict(
optimizer=dict(type='SGD', lr=0.1, weight_decay=0.0005, momentum=0.9)),
generator=dict(optimizer=dict(type='Adam', lr=1e-3)))
auto_scale_lr = dict(base_batch_size=16)

iter_size = 50

param_scheduler = dict(
_delete_=True,
architecture=dict(
type='MultiStepLR',
milestones=[100 * iter_size, 200 * iter_size],
by_epoch=False),
generator=dict(
type='MultiStepLR',
milestones=[100 * iter_size, 200 * iter_size],
by_epoch=False))

train_cfg = dict(
_delete_=True, by_epoch=False, max_iters=500 * iter_size, val_interval=250)

train_dataloader = dict(
batch_size=16, sampler=dict(type='InfiniteSampler', shuffle=True))
val_dataloader = dict(batch_size=16)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

default_hooks = dict(
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=100, max_keep_ckpts=2))

log_processor = dict(by_epoch=False)
# Must set diff_rank_seed to True!
randomness = dict(seed=None, diff_rank_seed=True)
Binary file added docs/en/imgs/model_zoo/zskt/distribution.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/en/imgs/model_zoo/zskt/synthesis.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion mmrazor/models/architectures/generators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dafl_generator import DAFLGenerator
from .zskt_generator import ZSKTGenerator

__all__ = ['DAFLGenerator']
__all__ = ['DAFLGenerator', 'ZSKTGenerator']
4 changes: 2 additions & 2 deletions mmrazor/models/architectures/generators/dafl_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,12 @@ def __init__(

def forward(self,
data: Optional[torch.Tensor] = None,
batch_size: int = 0) -> torch.Tensor:
batch_size: int = 1) -> torch.Tensor:
"""Forward function for generator.

Args:
data (torch.Tensor, optional): The input data. Defaults to None.
batch_size (int): Batch size. Defaults to 0.
batch_size (int): Batch size. Defaults to 1.
"""
batch_data = self.process_latent(data, batch_size)
img = self.linear(batch_data)
Expand Down
91 changes: 91 additions & 0 deletions mmrazor/models/architectures/generators/zskt_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn

from mmrazor.registry import MODELS
from .base_generator import BaseGenerator


class View(nn.Module):
"""Class for view tensors.

Args:
size (Tuple[int, ...]): Size of the output tensor.
"""

def __init__(self, size: Tuple[int, ...]) -> None:
super(View, self).__init__()
self.size = size

def forward(self, tensor: torch.Tensor) -> torch.Tensor:
""""Forward function for view tensors."""
return tensor.view(self.size)


@MODELS.register_module()
class ZSKTGenerator(BaseGenerator):
"""Generator for ZSKT. code link:
https://github.com/polo5/ZeroShotKnowledgeTransfer/

Args:
img_size (int): The size of generated image.
latent_dim (int): The dimension of latent data.
hidden_channels (int): The dimension of hidden channels.
scale_factor (int, optional): The scale factor for F.interpolate.
Defaults to 2.
leaky_slope (float, optional): The slope param in leaky relu. Defaults
to 0.2.
init_cfg (dict, optional): The config to control the initialization.
"""

def __init__(
self,
img_size: int,
latent_dim: int,
hidden_channels: int,
scale_factor: int = 2,
leaky_slope: float = 0.2,
init_cfg: Optional[Dict] = None,
) -> None:
super().__init__(
img_size, latent_dim, hidden_channels, init_cfg=init_cfg)
self.init_size = self.img_size // (scale_factor**2)
self.scale_factor = scale_factor

self.layers = nn.Sequential(
nn.Linear(self.latent_dim,
self.hidden_channels * self.init_size**2),
View((-1, self.hidden_channels, self.init_size, self.init_size)),
nn.BatchNorm2d(self.hidden_channels),
nn.Upsample(scale_factor=scale_factor),
nn.Conv2d(
self.hidden_channels,
self.hidden_channels,
3,
stride=1,
padding=1), nn.BatchNorm2d(self.hidden_channels),
nn.LeakyReLU(leaky_slope, inplace=True),
nn.Upsample(scale_factor=scale_factor),
nn.Conv2d(
self.hidden_channels,
self.hidden_channels // 2,
3,
stride=1,
padding=1), nn.BatchNorm2d(self.hidden_channels // 2),
nn.LeakyReLU(leaky_slope, inplace=True),
nn.Conv2d(self.hidden_channels // 2, 3, 3, stride=1, padding=1),
nn.BatchNorm2d(3, affine=True))

def forward(self,
data: Optional[torch.Tensor] = None,
batch_size: int = 1) -> torch.Tensor:
"""Forward function for generator.

Args:
data (torch.Tensor, optional): The input data. Defaults to None.
batch_size (int): Batch size. Defaults to 1.
"""
batch_data = self.process_latent(data, batch_size)
return self.layers(batch_data)
2 changes: 1 addition & 1 deletion mmrazor/models/architectures/heads/darts_subnet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import List, Tuple

import torch
from mmcls.data import ClsDataSample
from mmcls.evaluation import Accuracy
from mmcls.models.heads import LinearClsHead
from mmcls.structures import ClsDataSample
from torch import nn

from mmrazor.models.utils import add_prefix
Expand Down
2 changes: 1 addition & 1 deletion mmrazor/models/distillers/configurable_distiller.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def _check_loss_forward_mappings(
{type(loss_module).__name__} forward, \
please check your config.'

if (loss_forward_params[forward_key].default ==
if (loss_forward_params[forward_key].default !=
loss_forward_params[forward_key].empty):
# default params without check
continue
Expand Down
3 changes: 2 additions & 1 deletion mmrazor/models/losses/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ab_loss import ABLoss
from .at_loss import ATLoss
from .cwd import ChannelWiseDivergence
from .dafl_loss import ActivationLoss, InformationEntropyLoss, OnehotLikeLoss
from .decoupled_kd import DKDLoss
Expand All @@ -13,5 +14,5 @@
__all__ = [
'ChannelWiseDivergence', 'KLDivergence', 'AngleWiseRKD', 'DistanceWiseRKD',
'WSLD', 'L2Loss', 'ABLoss', 'DKDLoss', 'KDSoftCELoss', 'ActivationLoss',
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss'
'OnehotLikeLoss', 'InformationEntropyLoss', 'FTLoss', 'ATLoss'
]
41 changes: 41 additions & 0 deletions mmrazor/models/losses/at_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmrazor.registry import MODELS


@MODELS.register_module()
class ATLoss(nn.Module):
""""Paying More Attention to Attention: Improving the Performance of
Convolutional Neural Networks via Attention Transfer" Conference paper at
ICLR2017 https://openreview.net/forum?id=Sks9_ajex.

https://github.com/szagoruyko/attention-transfer/blob/master/utils.py

Args:
loss_weight (float): Weight of loss. Defaults to 1.0.
"""

def __init__(
self,
loss_weight: float = 1.0,
) -> None:
super().__init__()
self.loss_weight = loss_weight

def forward(self, s_feature: torch.Tensor,
t_feature: torch.Tensor) -> torch.Tensor:
""""Forward function for ATLoss."""
loss = (self.calc_attention_matrix(s_feature) -
self.calc_attention_matrix(t_feature)).pow(2).mean()
return self.loss_weight * loss

def calc_attention_matrix(self, x: torch.Tensor) -> torch.Tensor:
""""Calculate the attention matrix.

Args:
x (torch.Tensor): Input features.
"""
return F.normalize(x.pow(2).mean(1).view(x.size(0), -1))
Loading