"""PGD adversarial training / evaluation for a CIFAR-10 project model.

Loads a `Net` from a user-supplied project directory, then runs 200 epochs
of L-infinity PGD adversarial training (Madry et al., 2018), reporting
benign and adversarial accuracy and checkpointing after every test pass.
"""
import argparse
import copy  # needed by attack(); was missing in the original (NameError)
import importlib
import importlib.abc
import importlib.util  # spec_from_file_location is not reachable via bare `import importlib`
import os
import os.path
import sys

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

#from models import *

# Hyperparameters for PGD adversarial training.
learning_rate = 0.1
epsilon = 0.0314   # L-inf perturbation budget (~8/255)
k = 7              # number of PGD steps
alpha = 0.00784    # PGD step size (~2/255)
file_name = 'pgd_adversarial_training'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Standard CIFAR-10 augmentation for training; plain tensors for testing.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4)


class LinfPGDAttack(object):
    """Projected Gradient Descent attack under an L-infinity constraint."""

    def __init__(self, model):
        self.model = model

    def perturb(self, x_natural, y):
        """Return adversarial examples for the batch (x_natural, y).

        Starts from a uniform random point inside the epsilon-ball around
        x_natural, then takes k signed-gradient ascent steps of size alpha
        on the cross-entropy loss, projecting back onto the ball and
        clamping to the valid pixel range [0, 1] after every step.
        """
        x = x_natural.detach()
        # Random start inside the L-inf ball.
        x = x + torch.zeros_like(x).uniform_(-epsilon, epsilon)
        for _ in range(k):
            x.requires_grad_()
            # enable_grad allows this to work even when the caller is
            # inside a torch.no_grad() context (see test()).
            with torch.enable_grad():
                logits = self.model(x)
                loss = F.cross_entropy(logits, y)
            grad = torch.autograd.grad(loss, [x])[0]
            x = x.detach() + alpha * torch.sign(grad.detach())
            # Project back onto the epsilon-ball around the natural input.
            x = torch.min(torch.max(x, x_natural - epsilon), x_natural + epsilon)
            x = torch.clamp(x, 0, 1)
        return x


def attack(x, y, model, adversary):
    """Craft adversarial examples against a frozen copy of `model`.

    The model is deep-copied and switched to eval mode so the attack does
    not perturb batch-norm statistics or dropout state of the live model.
    """
    model_copied = copy.deepcopy(model)
    model_copied.eval()
    adversary.model = model_copied
    adv = adversary.perturb(x, y)
    return adv


def load_project(project_dir):
    """Import `model.py` from `project_dir` and return it as a module.

    Also prepends `project_dir` to sys.path so the project's own relative
    imports resolve. Raises FileNotFoundError if the directory or its
    model.py does not exist.
    """
    module_filename = os.path.join(project_dir, 'model.py')
    # isdir() already implies existence, so the extra exists() check is dropped.
    if os.path.isdir(project_dir) and os.path.isfile(module_filename):
        print("Found valid project in '{}'.".format(project_dir))
    else:
        print("Fatal: '{}' is not a valid project directory.".format(project_dir))
        raise FileNotFoundError(module_filename)
    sys.path = [project_dir] + sys.path
    spec = importlib.util.spec_from_file_location("model", module_filename)
    project_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(project_module)
    return project_module


#net = ResNet18()
# NOTE(review): torch.seed() seeds the RNG with a *random* value, so runs are
# not reproducible; use torch.manual_seed(<fixed int>) if determinism is wanted.
torch.seed()

parser = argparse.ArgumentParser()
parser.add_argument("project_dir", metavar="project-dir", nargs="?",
                    default=os.getcwd(),
                    help="Path to the project directory to test.")
parser.add_argument("-b", "--batch-size", type=int, default=256,
                    help="Set batch size.")
parser.add_argument("-s", "--num-samples", type=int, default=1,
                    help="Num samples for testing (required to test randomized networks).")


def train(epoch):
    """Run one epoch of adversarial training over train_loader.

    Trains exclusively on PGD adversarial examples (standard PGD-AT) and
    prints running adversarial accuracy/loss every 100 batches.
    """
    print('\n[ Train epoch: %d ]' % epoch)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()

        adv = adversary.perturb(inputs, targets)
        adv_outputs = net(adv)
        loss = criterion(adv_outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = adv_outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 100 == 0:
            print('\nCurrent batch:', str(batch_idx))
            print('Current adversarial train accuracy:',
                  str(predicted.eq(targets).sum().item() / targets.size(0)))
            print('Current adversarial train loss:', loss.item())

    print('\nTotal adversarial train accuracy:', 100. * correct / total)
    print('Total adversarial train loss:', train_loss)


def test(epoch):
    """Evaluate benign and adversarial accuracy over test_loader, then checkpoint.

    Evaluation runs under torch.no_grad(); the PGD attack still works because
    LinfPGDAttack.perturb re-enables gradients internally. The model state is
    saved to ./checkpoint/<file_name> after every call.
    """
    print('\n[ Test epoch: %d ]' % epoch)
    net.eval()
    benign_loss = 0
    adv_loss = 0
    benign_correct = 0
    adv_correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)

            # Benign (clean) evaluation.
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            benign_loss += loss.item()
            _, predicted = outputs.max(1)
            benign_correct += predicted.eq(targets).sum().item()
            if batch_idx % 10 == 0:
                print('\nCurrent batch:', str(batch_idx))
                print('Current benign test accuracy:',
                      str(predicted.eq(targets).sum().item() / targets.size(0)))
                print('Current benign test loss:', loss.item())

            # Adversarial evaluation on PGD examples.
            adv = adversary.perturb(inputs, targets)
            adv_outputs = net(adv)
            loss = criterion(adv_outputs, targets)
            adv_loss += loss.item()
            _, predicted = adv_outputs.max(1)
            adv_correct += predicted.eq(targets).sum().item()
            if batch_idx % 10 == 0:
                print('Current adversarial test accuracy:',
                      str(predicted.eq(targets).sum().item() / targets.size(0)))
                print('Current adversarial test loss:', loss.item())

    print('\nTotal benign test accuracy:', 100. * benign_correct / total)
    print('Total adversarial test accuracy:', 100. * adv_correct / total)
    print('Total benign test loss:', benign_loss)
    print('Total adversarial test loss:', adv_loss)

    state = {
        'net': net.state_dict()
    }
    # makedirs with exist_ok replaces the racy isdir()/mkdir() pair.
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, './checkpoint/' + file_name)
    print('Model Saved!')


def adjust_learning_rate(optimizer, epoch):
    """Step-decay schedule: lr/10 from epoch 100, lr/100 from epoch 150."""
    lr = learning_rate
    if epoch >= 100:
        lr /= 10
    if epoch >= 150:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


if __name__ == "__main__":
    args = parser.parse_args()
    project_module = load_project(args.project_dir)

    net = project_module.Net()
    net = net.to(device)
    # NOTE(review): load_for_testing is a project-defined method (contract
    # unknown from here); presumably it restores pretrained weights.
    net.load_for_testing(project_dir=args.project_dir)

    adversary = LinfPGDAttack(net)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                          momentum=0.9, weight_decay=0.0002)

    for epoch in range(0, 200):
        adjust_learning_rate(optimizer, epoch)
        train(epoch)
        test(epoch)