[Model Compression] Add Unit Test #4125
Changes from 17 commits
This file was deleted.
@@ -0,0 +1,103 @@
import functools
from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator
from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler

from examples.model_compress.models.cifar10.vgg import VGG


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()

def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def finetuner(model):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for data, target in tqdm(iterable=train_loader, desc='Epoch PFs'):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # pre-train the model
    for i in range(5):
        trainer(model, optimizer, criterion, i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]

    # Make sure to initialize the task generator first, because the model passed to the generator should be an unwrapped model.
    # If you want to initialize the pruner first, you can use the following code.

    # pruner = L1NormPruner(model, config_list)
    # pruner._unwrap_model()
    # task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermidiate_result=True)
    # pruner._wrap_model()

    # You can specify the log_dir; all intermediate results and the best result will be saved under this folder.
    # If you don't want to keep intermediate results, set `keep_intermidiate_result=False`.
    task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermidiate_result=True)
Review comment: 'intermediate' or 'intermidiate'?

Reply: oh, thank you, I have this mistake everywhere TT, all places have been fixed.
    pruner = L1NormPruner(model, config_list)

    dummy_input = torch.rand(10, 3, 32, 32).to(device)

    # If you just want to keep the final result as the best result, pass evaluator as None;
    # otherwise, the result with the highest score (given by the evaluator) will be kept as the best result.

    # scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
    scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=None)
Review comment: Interface is much more complicated than the original

Reply: Yes, it is a good suggestion, I demo a high-level interface in #4236, we can discuss it in this PR.

Reply: @J-shang please schedule a meeting, let's discuss the user interface today

    scheduler.compress()
zheng-ningxin marked this conversation as resolved.
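As a side note on the `AGPTaskGenerator(10, ...)` call above: AGP-style pruning usually raises the sparsity toward the target over the given number of iterations with a cubic schedule. The snippet below is an editorial sketch of that schedule, assuming the standard formula from Zhu & Gupta (2017); the exact schedule NNI implements may differ.

# Editorial sketch of the cubic AGP sparsity schedule (assumed formula:
# s_t = s_final + (s_init - s_final) * (1 - t / n) ** 3).
def agp_sparsity(target_sparsity, iteration, total_iterations, initial_sparsity=0.0):
    progress = min(max(iteration / total_iterations, 0.0), 1.0)  # clamp to [0, 1]
    return target_sparsity + (initial_sparsity - target_sparsity) * (1 - progress) ** 3

# With sparsity 0.8 and 10 iterations (as in the example), the per-iteration
# targets rise roughly as 0.0, 0.22, 0.39, ..., 0.8.
for t in range(11):
    print(t, round(agp_sparsity(0.8, t, 10), 3))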
@@ -0,0 +1,83 @@
from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

from examples.model_compress.models.cifar10.vgg import VGG


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()

def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    print('\nPre-train the model:')
    for i in range(5):
        trainer(model, optimizer, criterion, i)
    evaluator(model)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()

    print('\nThe accuracy with masks:')
    evaluator(model)

    pruner._unwrap_model()
    # Save the generated masks so ModelSpeedup can load them from the given path.
    torch.save(masks, 'simple_masks.pth')
    ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file='simple_masks.pth').speedup_model()

    print('\nThe accuracy after speed up:')
    evaluator(model)

    print('\nFinetune the model after speed up:')
    for i in range(5):
        trainer(model, optimizer, criterion, i)
    evaluator(model)
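A quick way to check that the speed-up step above actually shrank the network is to compare parameter counts before and after `ModelSpeedup(...).speedup_model()`. This is an editorial sketch using only standard PyTorch calls; `model` refers to the VGG instance from the example.

# Editorial helper: count trainable parameters to verify the effect of speed-up.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# Usage sketch around the example above:
#   before = count_parameters(model)
#   ... pruner.compress(), pruner._unwrap_model(), ModelSpeedup(...).speedup_model() ...
#   after = count_parameters(model)
#   print('params before: {}, params after: {}'.format(before, after))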
@@ -3,13 +3,13 @@
 import collections
 import logging
-from typing import List, Dict, Optional, OrderedDict, Tuple, Any
+from typing import List, Dict, Optional, Tuple, Any

 import torch
 from torch.nn import Module

 from nni.common.graph_utils import TorchModuleGraph
-from nni.compression.pytorch.utils import get_module_by_name
+from nni.algorithms.compression.v2.pytorch.utils.pruning import get_module_by_name
Review comment: Import path is too long~ Please remove

Reply: thx, modify it~

 _logger = logging.getLogger(__name__)
@@ -149,7 +149,7 @@ def _select_config(self, layer: LayerInfo) -> Optional[Dict]:
             return None
         return ret

-    def get_modules_wrapper(self) -> OrderedDict[str, Module]:
+    def get_modules_wrapper(self) -> Dict[str, Module]:
         """
         Returns
         -------
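For context on the signature change above: `get_modules_wrapper()` maps module names to their wrapped modules, and the runtime object can still be insertion-ordered even though the annotation is loosened to `Dict`. A hypothetical usage sketch (it assumes a `pruner` instance exists and relies only on the `Dict[str, Module]` return type shown in the diff):

# Hypothetical caller of get_modules_wrapper(); relies only on the
# Dict[str, Module] return type shown in the diff above.
for name, wrapper in pruner.get_modules_wrapper().items():
    print(name, type(wrapper).__name__)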
@@ -1 +1,3 @@
 from .basic_pruner import *
+from .basic_scheduler import PruningScheduler
+from .tools import AGPTaskGenerator, LinearTaskGenerator, LotteryTicketTaskGenerator, SimulatedAnnealingTaskGenerator
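With these re-exports, the deep module paths used in the first example can be shortened. A small illustration of the imports this change enables (editorial note, not part of the diff):

# These shorter imports are enabled by the re-exports added above.
from nni.algorithms.compression.v2.pytorch.pruning import (
    L1NormPruner,       # already exported via basic_pruner
    PruningScheduler,
    AGPTaskGenerator,
)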
Review comment: If the user calls `unwrap_model()` and initializes the pruner first, should the user call `wrap_model()` again before `scheduler.compress()`?

Reply: No, in fact, the `model` and `config_list` passed to the `pruner` won't be used at all. In the next update, I plan to support initializing the pruner in this way for the scheduler: `Pruner(model=None, config_list=None, ...)`
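To make the exchange above concrete, here is a minimal editorial sketch of the pruner-first initialization order, mirroring the commented-out lines in the first example file. It assumes `model`, `config_list`, `finetuner`, and `dummy_input` are defined exactly as in that example; the comment about re-wrapping reflects the reply above rather than documented behaviour.

# Pruner-first initialization, per the discussion above (editorial sketch).
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner, PruningScheduler, AGPTaskGenerator

pruner = L1NormPruner(model, config_list)
pruner._unwrap_model()   # the task generator expects an unwrapped model
task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermidiate_result=True)
pruner._wrap_model()     # shown in the first example; per the reply above it is not strictly required

scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner,
                             speed_up=True, dummy_input=dummy_input, evaluator=None)
scheduler.compress()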