diff --git a/.gitignore b/.gitignore index 264e72a..02a437c 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,8 @@ verifier_log_* .idea *.so release +*.compiled +.DS_Store +*.csv +*.out +*.txt \ No newline at end of file diff --git a/README.md b/README.md index 0c87c95..e0fdbb4 100644 --- a/README.md +++ b/README.md @@ -12,10 +12,12 @@ ## What's New? +- New activation function (sin, cos, tan, GeLU) with optimizable bounds (α-CROWN) and [branch and bound support](https://files.sri.inf.ethz.ch/wfvml23/papers/paper_24.pdf) for non-ReLU activation functions. We achieve significant improvements on verifying neural networks with non-ReLU activation functions such as Transformer and LSTM networks. (09/2023) +- [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) (using `auto_LiRPA` as its core library) **won** [VNN-COMP 2023](https://sites.google.com/view/vnn2023). (08/2023) - Bound computation for higher-order computational graphs to support bounding Jacobian, Jacobian-vector products, and [local Lipschitz constants](https://arxiv.org/abs/2210.07394). (11/2022) -- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) (using `auto_LiRPA` as its core library) **won** [VNN-COMP 2022](https://sites.google.com/view/vnn2022). Our library supports the large CIFAR100, TinyImageNet and ImageNet models in VNN-COMP 2022. (09/2022) +- Our neural network verification tool [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) (using `auto_LiRPA` as its core library) **won** [VNN-COMP 2022](https://sites.google.com/view/vnn2022). Our library supports the large CIFAR100, TinyImageNet and ImageNet models in VNN-COMP 2022. (09/2022) - Implementation of **general cutting planes** ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf)), support of more activation functions and improved performance and scalability. (09/2022) -- Our neural network verification tool [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. (09/2021) +- Our neural network verification tool [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) **won** [VNN-COMP 2021](https://sites.google.com/view/vnn2021) **with the highest total score**, outperforming 11 SOTA verifiers. α,β-CROWN uses the `auto_LiRPA` library as its core bound computation library. (09/2021) - [Optimized CROWN/LiRPA](https://arxiv.org/pdf/2011.13824.pdf) bound (α-CROWN) for ReLU, **sigmoid**, **tanh**, and **maxpool** activation functions, which can significantly outperform regular CROWN bounds. See [simple_verification.py](examples/vision/simple_verification.py#L59) for an example. (07/31/2021) - Handle split constraints for ReLU neurons ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for complete verifiers. (07/31/2021) - A memory efficient GPU implementation of backward (CROWN) bounds for @@ -46,12 +48,12 @@ Our library supports the following algorithms: * Backward mode LiRPA bound propagation ([CROWN](https://arxiv.org/pdf/1811.00866.pdf)/[DeepPoly](https://files.sri.inf.ethz.ch/website/papers/DeepPoly.pdf)) * Backward mode LiRPA bound propagation with optimized bounds ([α-CROWN](https://arxiv.org/pdf/2011.13824.pdf)) -* Backward mode LiRPA bound propagation with split constraints ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) +* Backward mode LiRPA bound propagation with split constraints ([β-CROWN](https://arxiv.org/pdf/2103.06624.pdf)) for ReLU, and ([Shi et al. 2023](https://files.sri.inf.ethz.ch/wfvml23/papers/paper_24.pdf)) for general nonlinear functions * Generalized backward mode LiRPA bound propagation with general cutting plane constraints ([GCP-CROWN](https://arxiv.org/pdf/2208.05740.pdf)) * Forward mode LiRPA bound propagation ([Xu et al., 2020](https://arxiv.org/pdf/2002.12920)) * Forward mode LiRPA bound propagation with optimized bounds (similar to [α-CROWN](https://arxiv.org/pdf/2011.13824.pdf)) * Interval bound propagation ([IBP](https://arxiv.org/pdf/1810.12715.pdf)) -* Hybrid approaches, e.g., Forward+Backward, IBP+Backward ([CROWN-IBP](https://arxiv.org/pdf/1906.06316.pdf)), [α,β-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/huanzhang12/alpha-beta-CROWN.git)) +* Hybrid approaches, e.g., Forward+Backward, IBP+Backward ([CROWN-IBP](https://arxiv.org/pdf/1906.06316.pdf)), [α,β-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git) ([alpha-beta-CROWN](https://github.com/Verified-Intelligence/alpha-beta-CROWN.git)) Our library allows automatic bound derivation and computation for general computational graphs, in a similar manner that gradients are obtained in modern @@ -99,7 +101,7 @@ See [PyTorch Get Started](https://pytorch.org/get-started). Then you can install `auto_LiRPA` via: ```bash -git clone https://github.com/KaidiXu/auto_LiRPA +git clone https://github.com/Verified-Intelligence/auto_LiRPA cd auto_LiRPA python setup.py install ``` @@ -159,10 +161,10 @@ We provide [a wide range of examples](doc/src/examples.md) of using `auto_LiRPA` * [Certified Adversarial Defense Training on Sequence Data with **LSTM**](doc/src/examples.md#certified-adversarial-defense-training-for-lstm-on-mnist) * [Certifiably Robust Language Classifier using **Transformers**](doc/src/examples.md#certifiably-robust-language-classifier-with-transformer-and-lstm) * [Certified Robustness against **Model Weight Perturbations**](doc/src/examples.md#certified-robustness-against-model-weight-perturbations-and-certified-defense) -* [Bounding **Jacobian** and **local Lipschitz constants**](examples/vision/jacobian.py) +* [Bounding **Jacobian** and **local Lipschitz constants**](examples/vision/jacobian_new.py) `auto_LiRPA` has also be used in the following works: -* [**α,β-CROWN for complete neural network verification**](https://github.com/huanzhang12/alpha-beta-CROWN) +* [**α,β-CROWN for complete neural network verification**](https://github.com/Verified-Intelligence/alpha-beta-CROWN) * [**Fast certified robust training**](https://github.com/shizhouxing/Fast-Certified-Robust-Training) * [**Computing local Lipschitz constants**](https://github.com/shizhouxing/Local-Lipschitz-Constants) @@ -177,7 +179,7 @@ For more documentations, please refer to: ## Publications -Please kindly cite our papers if you use the `auto_LiRPA` library. Full [BibTeX entries](doc/examples.md#bibtex-entries) can be found [here](doc/examples.md#bibtex-entries). +Please kindly cite our papers if you use the `auto_LiRPA` library. Full [BibTeX entries](doc/src/examples.md#bibtex-entries) can be found [here](doc/src/examples.md#bibtex-entries). The general LiRPA based bound propagation algorithm was originally proposed in our paper: @@ -207,26 +209,30 @@ Certified robust training using `auto_LiRPA` is improved to allow much shorter w NeurIPS 2021. Zhouxing Shi\*, Yihan Wang\*, Huan Zhang, Jinfeng Yi and Cho-Jui Hsieh (\* Equal contribution). -## Developers and Copyright +Branch and bound for non-ReLU and general activation functions: +* [Formal Verification for Neural Networks with General Nonlinearities via Branch-and-Bound](https://files.sri.inf.ethz.ch/wfvml23/papers/paper_24.pdf). +Zhouxing Shi\*, Qirui Jin\*, Zico Kolter, Suman Jana, Cho-Jui Hsieh, Huan Zhang (\* Equal contribution). -| [Kaidi Xu](https://kaidixu.com/) | [Zhouxing Shi](https://shizhouxing.github.io/) | [Huan Zhang](https://huan-zhang.com/) | [Yihan Wang](https://yihanwang617.github.io/) | [Shiqi Wang](https://www.cs.columbia.edu/~tcwangshiqi/) | -|:--:|:--:| :--:| :--:| :--:| -| | | | | | +## Developers and Copyright Team lead: -* Huan Zhang (huan@huan-zhang.com), CMU +* Huan Zhang (huan@huan-zhang.com), UIUC -Main developers: +Current developers: * Zhouxing Shi (zshi@cs.ucla.edu), UCLA +* Linyi Li (linyi2@illinois.edu), UIUC +* Christopher Brix (brix@cs.rwth-aachen.de), RWTH Aachen University * Kaidi Xu (kx46@drexel.edu), Drexel University +* Xiangru Zhong (xiangruzh0915@gmail.com), Sun Yat-sen University +* Qirui Jin (qiruijin@umich.edu), University of Michigan +* Zhuolin Yang (zhuolin5@illinois.edu), UIUC +* Zhuowen Yuan (realzhuowen@gmail.com), UIUC -Contributors: -* Yihan Wang (yihanwang@ucla.edu), UCLA +Past developers: * Shiqi Wang (sw3215@columbia.edu), Columbia University -* Linyi Li (linyi2@illinois.edu), UIUC +* Yihan Wang (yihanwang@ucla.edu), UCLA * Jinqi (Kathryn) Chen (jinqic@cs.cmu.edu), CMU -* Zhuolin Yang (zhuolin5@illinois.edu), UIUC -We thank the [commits](https://github.com/KaidiXu/auto_LiRPA/commits) and [pull requests](https://github.com/KaidiXu/auto_LiRPA/pulls) from community contributors. +We thank the [commits](https://github.com/Verified-Intelligence/auto_LiRPA/commits) and [pull requests](https://github.com/Verified-Intelligence/auto_LiRPA/pulls) from community contributors. Our library is released under the BSD 3-Clause license. diff --git a/README_abcrown.md b/README_abcrown.md deleted file mode 100644 index 339f4ed..0000000 --- a/README_abcrown.md +++ /dev/null @@ -1,245 +0,0 @@ -α,β-CROWN (alpha-beta-CROWN): A Fast and Scalable Neural Network Verifier with Efficient Bound Propagation -====================== - -

- -

- -α,β-CROWN (alpha-beta-CROWN) is a neural network verifier based on an efficient -bound propagation algorithm ([CROWN](https://arxiv.org/pdf/1811.00866.pdf)) and -branch and bound. It can be accelerated efficiently on **GPUs** and can scale -to relatively large convolutional networks. It also supports a wide range of -neural network architectures (e.g., **CNN**, **ResNet**, and various activation -functions), thanks to the versatile -[auto\_LiRPA](http://github.com/KaidiXu/auto_LiRPA) library developed by us. -α,β-CROWN can provide **provable robustness guarantees against adversarial -attacks** and can also verify other general properties of neural networks. - -α,β-CROWN is the **winning verifier** in [VNN-COMP -2021](https://sites.google.com/view/vnn2021) and [VNN-COMP -2022](https://sites.google.com/view/vnn2022) (International Verification of -Neural Networks Competition) with the highest total score, outperforming many -other neural network verifiers on a wide range of benchmarks over 2 years. -Details of competition results can be found in [VNN-COMP 2021 -slides](https://docs.google.com/presentation/d/1oM3NqqU03EUqgQVc3bGK2ENgHa57u-W6Q63Vflkv000/edit#slide=id.ge4496ad360_14_21), -[report](https://arxiv.org/abs/2109.00498) and [VNN-COMP 2022 slides (page -73)](https://drive.google.com/file/d/1nnRWSq3plsPvOT3V-drAF5D8zWGu02VF/view?usp=sharing). - -Supported Features ----------------------- - -

- -

- -Our verifier consists of the following core algorithms: - -* **β-CROWN** ([Wang et al. 2021](https://arxiv.org/pdf/2103.06624.pdf)): complete verification with **CROWN** ([Zhang et al. 2018](https://arxiv.org/pdf/1811.00866.pdf)) and branch and bound -* **α-CROWN** ([Xu et al., 2021](https://arxiv.org/pdf/2011.13824.pdf)): incomplete verification with optimized CROWN bound -* **GCP-CROWN** ([Zhang et al. 2021](https://arxiv.org/pdf/2208.05740.pdf)): CROWN-like bound propagation with general cutting plane constraints. -* **BaB-Attack** ([Zhang et al. 2021](https://proceedings.mlr.press/v162/zhang22ae/zhang22ae.pdf)): Branch and bound based adversarial attack for tackling hard instances. -* **MIP** ([Tjeng et al., 2017](https://arxiv.org/pdf/1711.07356.pdf)): mixed integer programming (slow but can be useful on small models). - -We support these neural network architectures: - -* Layers: fully connected (FC), convolutional (CNN), pooling (average pool and max pool), transposed convolution -* Activation functions: ReLU (incomplete/complete verification); sigmoid, tanh, arctan, sin, cos, tan (incomplete verification) -* Residual connections and other irregular graphs - -We support the following verification specifications: - -* Lp norm perturbation (p=1,2,infinity, as often used in robustness verification) -* VNNLIB format input (at most two layers of AND/OR clause, as used in VNN-COMP 2021 and 2022) -* Any linear specifications on neural network output (which can be added as a linear layer) - -We provide many example configurations in -[`complete_verifier/exp_configs`](/complete_verifier/exp_configs) directory to -start with: - -* MNIST: MLP and CNN models -* CIFAR-10, CIFAR-100, TinyImageNet: CNN and ResNet models -* ACASXu, NN4sys and other low input-dimension models - -See the [Guide on Algorithm -Selection](/complete_verifier/docs/abcrown_usage.md#guide-on-algorithm-selection) -to find the most suitable example to get started. - -Installation and Setup ----------------------- - -α,β-CROWN is tested on Python 3.7+ and PyTorch 1.11. It can be installed -easily into a conda environment. If you don't have conda, you can install -[miniconda](https://docs.conda.io/en/latest/miniconda.html). - -```bash -# Remove the old environment, if necessary. -conda deactivate; conda env remove --name alpha-beta-crown -# install all dependents into the alpha-beta-crown environment -conda env create -f complete_verifier/environment.yml --name alpha-beta-crown -# activate the environment -conda activate alpha-beta-crown -``` - -If you use the α-CROWN and/or β-CROWN verifiers (which covers the most use -cases), a Gurobi license is *not needed*. If you want to use MIP based -verification algorithms (feasible only for small MLP models), you need to -install a Gurobi license with the `grbgetkey` command. If you don't have -access to a license, by default the above installation procedure includes a -free and restricted license, which is actually sufficient for many relatively -small NNs. If you use the GCP-CROWN verifier, an installation of IBM CPlex -solver is required. Instructions to install the CPlex solver can be found -in the [VNN-COMP benchmark instructions](/complete_verifier/docs/vnn_comp.md#installation) -or the [GCP-CROWN instructions](https://github.com/tcwangshiqi-columbia/GCP-CROWN). - -If you prefer to install packages manually rather than using a prepared conda -environment, you can refer to this [installation -script](/vnncomp_scripts/install_tool_general.sh). - -If you want to run α,β-CROWN verifier on the VNN-COMP 2021 and 2022 benchmarks -(e.g., to make a comparison to a new verifier), you can follow [this -guide](/complete_verifier/docs/vnn_comp.md). - -Instructions ----------------------- - -We provide a unified front-end for the verifier, `abcrown.py`. All parameters -for the verifier are defined in a `yaml` config file. For example, to run -robustness verification on a CIFAR-10 ResNet network, you just run: - -```bash -conda activate alpha-beta-crown # activate the conda environment -cd complete_verifier -python abcrown.py --config exp_configs/cifar_resnet_2b.yaml -``` - -You can find explanations for most useful parameters in [this example config -file](/complete_verifier/exp_configs/cifar_resnet_2b.yaml). For detailed usage -and tutorial examples please see the [Usage -Documentation](/complete_verifier/docs/abcrown_usage.md). We also provide a -large range of examples in the -[`complete_verifier/exp_configs`](/complete_verifier/exp_configs) folder. - - -Publications ----------------------- - -If you use our verifier in your work, **please kindly cite our CROWN**([Zhang -et al., 2018](https://arxiv.org/pdf/1811.00866.pdf)), **α-CROWN** ([Xu et al., -2021](https://arxiv.org/pdf/2011.13824.pdf)), **β-CROWN**([Wang et al., -2021](https://arxiv.org/pdf/2103.06624.pdf)) and **GCP-CROWN**([Zhang et al., -2022](https://arxiv.org/pdf/2208.05740.pdf)) papers. If your work involves the -convex relaxation of the NN verification please kindly cite [Salman et al., -2019](https://arxiv.org/pdf/1902.08722). If your work deals with -ResNet/DenseNet, LSTM (recurrent networks), Transformer or other complex -architectures, or model weight perturbations please kindly cite [Xu et al., -2020](https://arxiv.org/pdf/2002.12920.pdf). If you use our branch and bound -based adversarial attack (falsifier), please cite [Zhang et al. -2022](https://proceedings.mlr.press/v162/zhang22ae/zhang22ae.pdf). - -α,β-CROWN combines our existing efforts on neural network verification: - -* **CROWN** ([Zhang et al. NeurIPS 2018](https://arxiv.org/pdf/1811.00866.pdf)) is a very efficient bound propagation based verification algorithm. CROWN propagates a linear inequality backwards through the network and utilizes linear bounds to relax activation functions. - -* The **"convex relaxation barrier"** ([Salman et al., NeurIPS 2019](https://arxiv.org/pdf/1902.08722)) paper concludes that optimizing the ReLU relaxation allows CROWN (referred to as a "greedy" primal space solver) to achieve the same solution as linear programming (LP) based verifiers. - -* **LiRPA** ([Xu et al., NeurIPS 2020](https://arxiv.org/pdf/2002.12920.pdf)) is a generalization of CROWN on general computational graphs and we also provide an efficient GPU implementation, the [auto\_LiRPA](https://github.com/KaidiXu/auto_LiRPA) library. - -* **α-CROWN** (sometimes referred to as optimized CROWN or optimized LiRPA) is used in the Fast-and-Complete verifier ([Xu et al., ICLR 2021](https://arxiv.org/pdf/2011.13824.pdf)), which jointly optimizes intermediate layer bounds and final layer bounds in CROWN via variable α. α-CROWN typically has greater power than LP since LP cannot cheaply tighten intermediate layer bounds. - -* **β-CROWN** ([Wang et al., NeurIPS 2021](https://arxiv.org/pdf/2103.06624.pdf)) incorporates split constraints in branch and bound (BaB) into the CROWN bound propagation procedure via an additional optimizable parameter β. The combination of efficient and GPU accelerated bound propagation with branch and bound produces a powerful and scalable neural network verifier. - -* **BaB-Attack** ([Zhang et al., ICML 2022](https://proceedings.mlr.press/v162/zhang22ae/zhang22ae.pdf)) is a strong falsifier (adversarial attack) based on branch and bound, which can find adversarial examples for hard instances where gradient or input-space-search based methods cannot succeed. - -* **GCP-CROWN** ([Zhang et al., NeurIPS 2022](https://arxiv.org/pdf/2208.05740.pdf)) enables the use of general cutting planes methods for neural network verification in a GPU-accelerated and very efficient bound propagation framework. - -We provide bibtex entries below: - -``` -@article{zhang2018efficient, - title={Efficient Neural Network Robustness Certification with General Activation Functions}, - author={Zhang, Huan and Weng, Tsui-Wei and Chen, Pin-Yu and Hsieh, Cho-Jui and Daniel, Luca}, - journal={Advances in Neural Information Processing Systems}, - volume={31}, - pages={4939--4948}, - year={2018}, - url={https://arxiv.org/pdf/1811.00866.pdf} -} - -@article{xu2020automatic, - title={Automatic perturbation analysis for scalable certified robustness and beyond}, - author={Xu, Kaidi and Shi, Zhouxing and Zhang, Huan and Wang, Yihan and Chang, Kai-Wei and Huang, Minlie and Kailkhura, Bhavya and Lin, Xue and Hsieh, Cho-Jui}, - journal={Advances in Neural Information Processing Systems}, - volume={33}, - year={2020} -} - -@article{salman2019convex, - title={A Convex Relaxation Barrier to Tight Robustness Verification of Neural Networks}, - author={Salman, Hadi and Yang, Greg and Zhang, Huan and Hsieh, Cho-Jui and Zhang, Pengchuan}, - journal={Advances in Neural Information Processing Systems}, - volume={32}, - pages={9835--9846}, - year={2019} -} - -@inproceedings{xu2021fast, - title={{Fast and Complete}: Enabling Complete Neural Network Verification with Rapid and Massively Parallel Incomplete Verifiers}, - author={Kaidi Xu and Huan Zhang and Shiqi Wang and Yihan Wang and Suman Jana and Xue Lin and Cho-Jui Hsieh}, - booktitle={International Conference on Learning Representations}, - year={2021}, - url={https://openreview.net/forum?id=nVZtXBI6LNn} -} - -@article{wang2021beta, - title={{Beta-CROWN}: Efficient bound propagation with per-neuron split constraints for complete and incomplete neural network verification}, - author={Wang, Shiqi and Zhang, Huan and Xu, Kaidi and Lin, Xue and Jana, Suman and Hsieh, Cho-Jui and Kolter, J Zico}, - journal={Advances in Neural Information Processing Systems}, - volume={34}, - year={2021} -} - -@InProceedings{zhang22babattack, - title = {A Branch and Bound Framework for Stronger Adversarial Attacks of {R}e{LU} Networks}, - author = {Zhang, Huan and Wang, Shiqi and Xu, Kaidi and Wang, Yihan and Jana, Suman and Hsieh, Cho-Jui and Kolter, Zico}, - booktitle = {Proceedings of the 39th International Conference on Machine Learning}, - volume = {162}, - pages = {26591--26604}, - year = {2022}, -} - -@article{zhang2022general, - title={General Cutting Planes for Bound-Propagation-Based Neural Network Verification}, - author={Zhang, Huan and Wang, Shiqi and Xu, Kaidi and Li, Linyi and Li, Bo and Jana, Suman and Hsieh, Cho-Jui and Kolter, J Zico}, - journal={Advances in Neural Information Processing Systems}, - year={2022} -} -``` - -Developers and Copyright ----------------------- - -The α,β-CROWN verifier is developed by a team from CMU, UCLA, Drexel University, Columbia University and UIUC: - -Team lead: -* Huan Zhang (huan@huan-zhang.com), CMU - -Main developers: -* Kaidi Xu (kx46@drexel.edu), Drexel University -* Zhouxing Shi (zshi@cs.ucla.edu), UCLA -* Shiqi Wang (sw3215@columbia.edu), Columbia University - -Contributors: -* Linyi Li (linyi2@illinois.edu), UIUC -* Jinqi (Kathryn) Chen (jinqic@cs.cmu.edu), CMU -* Zhuolin Yang (zhuolin5@illinois.edu), UIUC -* Yihan Wang (yihanwang@ucla.edu), UCLA - -Advisors: -* Zico Kolter (zkolter@cs.cmu.edu), CMU -* Cho-Jui Hsieh (chohsieh@cs.ucla.edu), UCLA -* Suman Jana (suman@cs.columbia.edu), Columbia University -* Bo Li (lbo@illinois.edu), UIUC -* Xue Lin (xue.lin@northeastern.edu), Northeastern University - -Our library is released under the BSD 3-Clause license. A copy of the license is included [here](LICENSE). - diff --git a/auto_LiRPA/__init__.py b/auto_LiRPA/__init__.py index 42981d0..47ddb9a 100644 --- a/auto_LiRPA/__init__.py +++ b/auto_LiRPA/__init__.py @@ -5,4 +5,4 @@ from .wrapper import CrossEntropyWrapper, CrossEntropyWrapperMultiInput from .bound_op_map import register_custom_op, unregister_custom_op -__version__ = '0.3.1' +__version__ = '0.4.0' diff --git a/auto_LiRPA/backward_bound.py b/auto_LiRPA/backward_bound.py index 01d973f..b862761 100644 --- a/auto_LiRPA/backward_bound.py +++ b/auto_LiRPA/backward_bound.py @@ -1,23 +1,34 @@ +import os import torch from torch import Tensor -from collections import deque, defaultdict +from collections import deque from tqdm import tqdm from .patches import Patches from .utils import * from .bound_ops import * import warnings +from typing import TYPE_CHECKING, List +if TYPE_CHECKING: + from .bound_general import BoundedModule -def batched_backward( - self, node, C, unstable_idx, batch_size, bound_lower=True, - bound_upper=True): + +def batched_backward(self: 'BoundedModule', node, C, unstable_idx, batch_size, + bound_lower=True, bound_upper=True, return_A=None): + if return_A is None: return_A = self.return_A crown_batch_size = self.bound_opts['crown_batch_size'] - unstable_size = get_unstable_size(unstable_idx) - print(f'Batched CROWN: unstable size {unstable_size}') - num_batches = (unstable_size + crown_batch_size - 1) // crown_batch_size output_shape = node.output_shape[1:] dim = int(prod(output_shape)) + if unstable_idx is None: + unstable_idx = torch.arange(dim, device=self.device) + dense = True + else: + dense = False + unstable_size = get_unstable_size(unstable_idx) + print(f'Batched CROWN: node {node}, unstable size {unstable_size}') + num_batches = (unstable_size + crown_batch_size - 1) // crown_batch_size ret = [] + ret_A = {} # if return_A, we will store A here for i in tqdm(range(num_batches)): if isinstance(unstable_idx, tuple): unstable_idx_batch = tuple( @@ -34,7 +45,7 @@ def batched_backward( unstable_size_batch, batch_size, *node.output_shape[1:-2], 1, 1], identity=1, unstable_idx=unstable_idx_batch, output_shape=[batch_size, *node.output_shape[1:]]) - elif isinstance(node, BoundLinear) or isinstance(node, BoundMatMul): + elif isinstance(node, (BoundLinear, BoundMatMul)): assert C in ['OneHot', None] C_batch = OneHotC( [batch_size, unstable_size_batch, *node.output_shape[1:]], @@ -45,71 +56,154 @@ def batched_backward( C_batch[0, torch.arange(unstable_size_batch), unstable_idx_batch] = 1.0 C_batch = C_batch.expand(batch_size, -1, -1).view( batch_size, unstable_size_batch, *output_shape) - ret.append(self.backward_general( - C=C_batch, node=node, + # overwrite return_A options to run backward general + ori_return_A_option = self.return_A + self.return_A = return_A + + batch_ret = self.backward_general( + node, C_batch, bound_lower=bound_lower, bound_upper=bound_upper, average_A=False, need_A_only=False, unstable_idx=unstable_idx_batch, - verbose=False)) + verbose=False) + ret.append(batch_ret[:2]) + + if len(batch_ret) > 2: + # A found, we merge A + batch_A = batch_ret[2] + ret_A = merge_A(batch_A, ret_A) + + # restore return_A options + self.return_A = ori_return_A_option if bound_lower: lb = torch.cat([item[0].view(batch_size, -1) for item in ret], dim=1) + if dense: + # In this case, restore_sparse_bounds will not be called. + # And thus we restore the shape here. + lb = lb.reshape(batch_size, *output_shape) else: lb = None if bound_upper: ub = torch.cat([item[1].view(batch_size, -1) for item in ret], dim=1) + if dense: + # In this case, restore_sparse_bounds will not be called. + # And thus we restore the shape here. + ub = ub.reshape(batch_size, *output_shape) else: ub = None - return lb, ub + + if return_A: + return lb, ub, ret_A + else: + return lb, ub def backward_general( - self, C, node=None, bound_lower=True, bound_upper=True, average_A=False, - need_A_only=False, unstable_idx=None, unstable_size=0, update_mask=None, verbose=True): + self: 'BoundedModule', + bound_node, + C, + start_backpropagation_at_node = None, + bound_lower=True, + bound_upper=True, + average_A=False, + need_A_only=False, + unstable_idx=None, + update_mask=None, + verbose=True, + apply_output_constraints_to: Optional[List[str]] = None, + initial_As: Optional[dict] = None, + initial_lb: Optional[torch.tensor] = None, + initial_ub: Optional[torch.tensor] = None, +): + use_beta_crown = self.bound_opts['optimize_bound_args']['enable_beta_crown'] + + if bound_node.are_output_constraints_activated_for_layer(apply_output_constraints_to): + assert not use_beta_crown + assert not self.cut_used + assert initial_As is None + assert initial_lb is None + assert initial_ub is None + return self.backward_general_with_output_constraint( + bound_node=bound_node, + C=C, + start_backporpagation_at_node=start_backpropagation_at_node, + bound_lower=bound_lower, + bound_upper=bound_upper, + average_A=average_A, + need_A_only=need_A_only, + unstable_idx=unstable_idx, + update_mask=update_mask, + verbose=verbose, + ) + + roots = self.roots() + + if start_backpropagation_at_node is None: + # When output constraints are used, backward_general_with_output_constraint() + # adds additional layers at the end, performs the backpropagation through these, + # and then calls backward_general() on the output layer. + # In this case, the layer we start from (start_backpropagation_at_node) differs + # from the layer that should be bounded (bound_node) + + # When output constraints are not used, the bounded node is the one where + # backpropagation starts. + start_backpropagation_at_node = bound_node + if verbose: - logger.debug(f'Bound backward from {node.__class__.__name__}({node.name})') + logger.debug(f'Bound backward from {start_backpropagation_at_node.__class__.__name__}({start_backpropagation_at_node.name}) ' + f'to bound {bound_node.__class__.__name__}({bound_node.name})') if isinstance(C, str): logger.debug(f' C: {C}') elif C is not None: logger.debug(f' C: shape {C.shape}, type {type(C)}') - _print_time = False + _print_time = bool(os.environ.get('AUTOLIRPA_PRINT_TIME', 0)) if isinstance(C, str): # If C is a str, use batched CROWN. If batched CROWN is not intended to # be enabled, C must be a explicitly provided non-str object for this function. - if need_A_only or self.return_A or average_A: + if need_A_only or average_A: raise ValueError( 'Batched CROWN is not compatible with ' - f'need_A_only={need_A_only}, return_A={self.return_A}, ' - f'average_A={average_A}') - node.lower, node.upper = self.batched_backward( - node, C, unstable_idx, - batch_size=self.root[0].value.shape[0], + f'need_A_only={need_A_only}, average_A={average_A}') + ret = self.batched_backward( + bound_node, C, unstable_idx, + batch_size=roots[0].value.shape[0], bound_lower=bound_lower, bound_upper=bound_upper, ) - return node.lower, node.upper + bound_node.lower, bound_node.upper = ret[:2] + return ret - for l in self._modules.values(): - l.lA = l.uA = None - l.bounded = True + for n in self.nodes(): + n.lA = n.uA = None - degree_out = get_degrees(node, self.backward_from) - all_nodes_before = list(degree_out.keys()) - - C, batch_size, output_dim, output_shape = preprocess_C(C, node) - - node.lA = C if bound_lower else None - node.uA = C if bound_upper else None - lb = ub = torch.tensor(0., device=self.device) + degree_out = get_degrees(start_backpropagation_at_node) + C, batch_size, output_dim, output_shape = self._preprocess_C(C, bound_node) + if initial_As is None: + start_backpropagation_at_node.lA = C if bound_lower else None + start_backpropagation_at_node.uA = C if bound_upper else None + else: + for layer_name, (lA, uA) in initial_As.items(): + self[layer_name].lA = lA + self[layer_name].uA = uA + assert start_backpropagation_at_node.lA is not None or start_backpropagation_at_node.uA is not None + if initial_lb is None: + lb = torch.tensor(0., device=self.device) + else: + lb = initial_lb + if initial_ub is None: + ub = torch.tensor(0., device=self.device) + else: + ub = initial_ub # Save intermediate layer A matrices when required. A_record = {} - queue = deque([node]) + queue = deque([start_backpropagation_at_node]) while len(queue) > 0: l = queue.popleft() # backward from l - l.bounded = True + self.backward_from[l.name].append(bound_node) - if l.name in self.root_name: continue + if l.name in self.root_names: continue # if all the succeeds are done, then we can turn to this node in the # next iteration. @@ -142,53 +236,69 @@ def backward_general( # A matrices are all zero, no need to propagate. continue - if isinstance(l, BoundRelu): - # TODO: unify this interface. - A, lower_b, upper_b = l.bound_backward( - l.lA, l.uA, *l.inputs, start_node=node, unstable_idx=unstable_idx, - beta_for_intermediate_layers=self.intermediate_constr is not None) - elif isinstance(l, BoundOptimizableActivation): + lA, uA = l.lA, l.uA + if (l.name != start_backpropagation_at_node.name and use_beta_crown + and getattr(l, 'sparse_betas', None)): + lA, uA, lbias, ubias = self.beta_crown_backward_bound( + l, lA, uA, start_node=start_backpropagation_at_node) + lb = lb + lbias + ub = ub + ubias + + if isinstance(l, BoundOptimizableActivation): # For other optimizable activation functions (TODO: unify with ReLU). - if node.name != self.final_node_name: - start_shape = node.output_shape[1:] + if bound_node.name != self.final_node_name: + start_shape = bound_node.output_shape[1:] else: start_shape = C.shape[0] - A, lower_b, upper_b = l.bound_backward( - l.lA, l.uA, *l.inputs, start_shape=start_shape, start_node=node) + l.preserve_mask = update_mask else: - A, lower_b, upper_b = l.bound_backward(l.lA, l.uA, *l.inputs) + start_shape = None + A, lower_b, upper_b = l.bound_backward( + lA, uA, *l.inputs, + start_node=bound_node, unstable_idx=unstable_idx, + start_shape=start_shape) + # After propagation through this node, we delete its lA, uA variables. - if not self.return_A and node.name != self.final_name: + if bound_node.name != self.final_name: del l.lA, l.uA if _print_time: + torch.cuda.synchronize() time_elapsed = time.time() - start_time - if time_elapsed > 1e-3: + if time_elapsed > 5e-3: print(l, time_elapsed) if lb.ndim > 0 and type(lower_b) == Tensor and self.conv_mode == 'patches': lb, ub, lower_b, upper_b = check_patch_biases(lb, ub, lower_b, upper_b) lb = lb + lower_b ub = ub + upper_b - if self.return_A and self.needed_A_dict and node.name in self.needed_A_dict: + if self.return_A and self.needed_A_dict and bound_node.name in self.needed_A_dict: # FIXME remove [0][0] and [0][1]? - if len(self.needed_A_dict[node.name]) == 0 or l.name in self.needed_A_dict[node.name]: - A_record.update({l.name: { - "lA": A[0][0].transpose(0, 1).detach() if A[0][0] is not None else None, - "uA": A[0][1].transpose(0, 1).detach() if A[0][1] is not None else None, - # When not used, lb or ub is tensor(0). - "lbias": lb.transpose(0, 1).detach() if lb.ndim > 1 else None, - "ubias": ub.transpose(0, 1).detach() if ub.ndim > 1 else None, - "unstable_idx": unstable_idx + if len(self.needed_A_dict[bound_node.name]) == 0 or l.name in self.needed_A_dict[bound_node.name]: + # A could be either patches (in this case we cannot transpose so directly return) + # or matrix (in this case we transpose) + A_record.update({ + l.name: { + "lA": ( + A[0][0] if isinstance(A[0][0], Patches) + else A[0][0].transpose(0, 1).detach() + ) if A[0][0] is not None else None, + "uA": ( + A[0][1] if isinstance(A[0][1], Patches) + else A[0][1].transpose(0, 1).detach() + ) if A[0][1] is not None else None, + # When not used, lb or ub is tensor(0). + "lbias": lb.transpose(0, 1).detach() if lb.ndim > 1 else None, + "ubias": ub.transpose(0, 1).detach() if ub.ndim > 1 else None, + "unstable_idx": unstable_idx }}) # FIXME: solve conflict with the following case - self.A_dict.update({node.name: A_record}) - if need_A_only and set(self.needed_A_dict[node.name]) == set(A_record.keys()): + self.A_dict.update({bound_node.name: A_record}) + if need_A_only and set(self.needed_A_dict[bound_node.name]) == set(A_record.keys()): # We have collected all A matrices we need. We can return now! - self.A_dict.update({node.name: A_record}) + self.A_dict.update({bound_node.name: A_record}) # Do not concretize to save time. We just need the A matrices. - # return A matrix as a dict: {node.name: [A_lower, A_upper]} + # return A matrix as a dict: {node_start.name: [A_lower, A_upper]} return None, None, self.A_dict - for i, l_pre in enumerate(l.inputs): add_bound(l, l_pre, lA=A[i][0], uA=A[i][1]) @@ -197,26 +307,29 @@ def backward_general( if ub.ndim >= 2: ub = ub.transpose(0, 1) - if self.return_A and self.needed_A_dict and node.name in self.needed_A_dict: + if self.return_A and self.needed_A_dict and bound_node.name in self.needed_A_dict: save_A_record( - node, A_record, self.A_dict, self.root, self.needed_A_dict[node.name], + bound_node, A_record, self.A_dict, roots, + self.needed_A_dict[bound_node.name], lb=lb, ub=ub, unstable_idx=unstable_idx) # TODO merge into `concretize` - if self.cut_used and getattr(self, 'cut_module', None) is not None and self.cut_module.x_coeffs is not None: + if (self.cut_used and getattr(self, 'cut_module', None) is not None + and self.cut_module.x_coeffs is not None): # propagate input neuron in cut constraints - self.root[0].lA, self.root[0].uA = self.cut_module.input_cut( - node, self.root[0].lA, self.root[0].uA, self.root[0].lower.size()[1:], unstable_idx, + roots[0].lA, roots[0].uA = self.cut_module.input_cut( + bound_node, roots[0].lA, roots[0].uA, roots[0].lower.size()[1:], unstable_idx, batch_mask=update_mask) - lb, ub = concretize( - lb, ub, node, self.root, batch_size, output_dim, - bound_lower, bound_upper, average_A=average_A) + lb, ub = concretize(self, batch_size, output_dim, lb, ub, + bound_lower, bound_upper, + average_A=average_A, node_start=bound_node) # TODO merge into `concretize` - if self.cut_used and getattr(self, "cut_module", None) is not None and self.cut_module.cut_bias is not None: + if (self.cut_used and getattr(self, "cut_module", None) is not None + and self.cut_module.cut_bias is not None): # propagate cut bias in cut constraints - lb, ub = self.cut_module.bias_cut(node, lb, ub, unstable_idx, batch_mask=update_mask) + lb, ub = self.cut_module.bias_cut(bound_node, lb, ub, unstable_idx, batch_mask=update_mask) if lb is not None and ub is not None and ((lb-ub)>0).sum().item() > 0: # make sure there is no bug for cut constraints propagation print(f"Warning: lb is larger than ub with diff: {(lb-ub)[(lb-ub)>0].max().item()}") @@ -240,23 +353,26 @@ def get_unstable_size(unstable_idx): return unstable_idx.numel() -def check_optimized_variable_sparsity(self, node): - alpha_sparsity = None # unknown. +def check_optimized_variable_sparsity(self: 'BoundedModule', node): + alpha_sparsity = None # unknown, optimizable variables are not created for this node. for relu in self.relus: - if hasattr(relu, 'alpha_lookup_idx') and node.name in relu.alpha_lookup_idx: + # FIXME: this hardcoded for ReLUs. Need to support other optimized nonlinear functions. + # alpha_lookup_idx is only created for sparse-spec alphas. + if relu.alpha_lookup_idx is not None and node.name in relu.alpha_lookup_idx: if relu.alpha_lookup_idx[node.name] is not None: # This node was created with sparse alpha alpha_sparsity = True + elif self.bound_opts['optimize_bound_args']['use_shared_alpha']: + # Shared alpha, the spec dimension is 1, and sparsity can be supported. + alpha_sparsity = True else: alpha_sparsity = False break - # print(f'node {node.name} alpha sparsity {alpha_sparsity}') return alpha_sparsity -def get_sparse_C( - self, node, sparse_intermediate_bounds=True, ref_intermediate_lb=None, - ref_intermediate_ub=None): +def get_sparse_C(self: 'BoundedModule', node, sparse_intermediate_bounds=True, + ref_intermediate_lb=None, ref_intermediate_ub=None): sparse_conv_intermediate_bounds = self.bound_opts.get('sparse_conv_intermediate_bounds', False) minimum_sparsity = self.bound_opts.get('minimum_sparsity', 0.9) crown_batch_size = self.bound_opts.get('crown_batch_size', 1e9) @@ -276,8 +392,10 @@ def get_sparse_C( if (isinstance(node, BoundLinear) or isinstance(node, BoundMatMul)) and int( os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0: if sparse_intermediate_bounds: - # If we are doing bound refinement and reference bounds are given, we only refine unstable neurons. - # Also, if we are checking against LP solver we will refine all neurons and do not use this optimization. + # If we are doing bound refinement and reference bounds are given, + # we only refine unstable neurons. + # Also, if we are checking against LP solver we will refine all + # neurons and do not use this optimization. # For each batch element, we find the unstable neurons. unstable_idx, unstable_size = self.get_unstable_locations( ref_intermediate_lb, ref_intermediate_ub) @@ -289,9 +407,12 @@ def get_sparse_C( # Create C in batched CROWN newC = 'OneHot' reduced_dim = True - elif (unstable_size <= minimum_sparsity * dim and unstable_size > 0 and alpha_is_sparse is None) or alpha_is_sparse: - # When we already have sparse alpha for this layer, we always use sparse C. Otherwise we determine it by sparsity. - # Create an abstract C matrix, the unstable_idx are the non-zero elements in specifications for all batches. + elif ((0 < unstable_size <= minimum_sparsity * dim + and alpha_is_sparse is None) or alpha_is_sparse): + # When we already have sparse alpha for this layer, we always + # use sparse C. Otherwise we determine it by sparsity. + # Create an abstract C matrix, the unstable_idx are the non-zero + # elements in specifications for all batches. newC = OneHotC( [batch_size, unstable_size, *node.output_shape[1:]], self.device, unstable_idx, None) @@ -300,7 +421,10 @@ def get_sparse_C( unstable_idx = None del ref_intermediate_lb, ref_intermediate_ub if not reduced_dim: - newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device) + if dim > crown_batch_size: + newC = 'eye' + else: + newC = eyeC([batch_size, dim, *node.output_shape[1:]], self.device) elif node.patches_start and node.mode == "patches": if sparse_intermediate_bounds: unstable_idx, unstable_size = self.get_unstable_locations( @@ -323,7 +447,8 @@ def get_sparse_C( # elements in specifications for all batches. # The shape of patches is [unstable_size, batch, C, H, W]. newC = Patches( - shape=[unstable_size, batch_size, *node.output_shape[1:-2], 1, 1], + shape=[unstable_size, batch_size, *node.output_shape[1:-2], + 1, 1], identity=1, unstable_idx=unstable_idx, output_shape=[batch_size, *node.output_shape[1:]]) reduced_dim = True @@ -336,7 +461,8 @@ def get_sparse_C( None, 1, 0, [node.output_shape[1], batch_size, *node.output_shape[2:], *node.output_shape[1:-2], 1, 1], 1, output_shape=[batch_size, *node.output_shape[1:]]) - elif isinstance(node, (BoundAdd, BoundSub)) and node.mode == "patches": + elif (isinstance(node, (BoundAdd, BoundSub)) and node.mode == "patches" + and len(node.output_shape) >= 4): # FIXME: BoundAdd does not always have patches. Need to use a better way # to determine patches mode. # FIXME: We should not hardcode BoundAdd here! @@ -361,9 +487,12 @@ def get_sparse_C( dtype=list(self.parameters())[0].dtype)).view( num_channel, 1, 1, 1, num_channel, 1, 1) # Expand to (out_c, 1, unstable_size, out_c, 1, 1). - patches = patches.expand(-1, 1, node.output_shape[-2], node.output_shape[-1], -1, 1, 1) - patches = patches[unstable_idx[0], :, unstable_idx[1], unstable_idx[2]] - # Expand with the batch dimension. Final shape (unstable_size, batch_size, out_c, 1, 1). + patches = patches.expand(-1, 1, node.output_shape[-2], + node.output_shape[-1], -1, 1, 1) + patches = patches[unstable_idx[0], :, + unstable_idx[1], unstable_idx[2]] + # Expand with the batch dimension. Final shape + # (unstable_size, batch_size, out_c, 1, 1). patches = patches.expand(-1, batch_size, -1, -1, -1) newC = Patches( patches, 1, 0, patches.shape, unstable_idx=unstable_idx, @@ -380,8 +509,10 @@ def get_sparse_C( dtype=list(self.parameters())[0].dtype)).view( num_channel, 1, 1, 1, num_channel, 1, 1) # Expand to (out_c, batch, out_h, out_w, out_c, 1, 1). - patches = patches.expand(-1, batch_size, node.output_shape[-2], node.output_shape[-1], -1, 1, 1) - newC = Patches(patches, 1, 0, patches.shape, output_shape=[batch_size, *node.output_shape[1:]]) + patches = patches.expand(-1, batch_size, node.output_shape[-2], + node.output_shape[-1], -1, 1, 1) + newC = Patches(patches, 1, 0, patches.shape, output_shape=[ + batch_size, *node.output_shape[1:]]) else: if sparse_intermediate_bounds: unstable_idx, unstable_size = self.get_unstable_locations( @@ -394,28 +525,37 @@ def get_sparse_C( # Create in C in batched CROWN newC = 'eye' reduced_dim = True - elif (unstable_size <= minimum_sparsity * dim and alpha_is_sparse is None) or alpha_is_sparse: + elif (unstable_size <= minimum_sparsity * dim + and alpha_is_sparse is None) or alpha_is_sparse: newC = torch.zeros([1, unstable_size, dim], device=self.device) # Fill the corresponding elements to 1.0 newC[0, torch.arange(unstable_size), unstable_idx] = 1.0 - newC = newC.expand(batch_size, -1, -1).view(batch_size, unstable_size, *node.output_shape[1:]) + newC = newC.expand(batch_size, -1, -1).view( + batch_size, unstable_size, *node.output_shape[1:]) reduced_dim = True else: unstable_idx = None del ref_intermediate_lb, ref_intermediate_ub if not reduced_dim: if dim > 1000: - warnings.warn(f"Creating an identity matrix with size {dim}x{dim} for node {node}. This may indicate poor performance for bound computation. If you see this message on a small network please submit a bug report.", stacklevel=2) - newC = torch.eye(dim, device=self.device, dtype=list(self.parameters())[0].dtype) \ - .unsqueeze(0).expand(batch_size, -1, -1) \ - .view(batch_size, dim, *node.output_shape[1:]) + warnings.warn( + f"Creating an identity matrix with size {dim}x{dim} for node {node}. " + "This may indicate poor performance for bound computation. " + "If you see this message on a small network please submit " + "a bug report.", stacklevel=2) + if dim > crown_batch_size: + newC = 'eye' + else: + newC = torch.eye(dim, device=self.device).unsqueeze(0).expand( + batch_size, -1, -1 + ).view(batch_size, dim, *node.output_shape[1:]) return newC, reduced_dim, unstable_idx, unstable_size -def restore_sparse_bounds( - self, node, unstable_idx, unstable_size, ref_intermediate_lb, ref_intermediate_ub, - new_lower=None, new_upper=None): +def restore_sparse_bounds(self: 'BoundedModule', node, unstable_idx, + unstable_size, ref_intermediate_lb, + ref_intermediate_ub, new_lower=None, new_upper=None): batch_size = self.batch_size if unstable_size == 0: # No unstable neurons. Skip the update. @@ -447,22 +587,26 @@ def restore_sparse_bounds( node.upper = upper.view(batch_size, *node.output_shape[1:]) -def get_degrees(node_start, backward_from): +def get_degrees(node_start): + if not isinstance(node_start, list): + node_start = [node_start] degrees = {} - queue = deque([node_start]) - node_start.bounded = False + added = {} + queue = deque() + for node in node_start: + queue.append(node) + added[node.name] = True while len(queue) > 0: l = queue.popleft() - backward_from[l.name].append(node_start) for l_pre in l.inputs: degrees[l_pre.name] = degrees.get(l_pre.name, 0) + 1 - if l_pre.bounded: - l_pre.bounded = False + if not added.get(l_pre.name, False): queue.append(l_pre) + added[l_pre.name] = True return degrees -def preprocess_C(C, node): +def _preprocess_C(self: 'BoundedModule', C, node): if isinstance(C, Patches): if C.unstable_idx is None: # Patches have size (out_c, batch, out_h, out_w, c, h, w). @@ -478,9 +622,11 @@ def preprocess_C(C, node): else: batch_size, output_dim = C.shape[:2] - # The C matrix specified by the user has shape (batch, spec) but internally we have (spec, batch) format. + # The C matrix specified by the user has shape (batch, spec) + # but internally we have (spec, batch) format. if not isinstance(C, (eyeC, Patches, OneHotC)): - C = C.transpose(0, 1) + C = C.transpose(0, 1).reshape( + output_dim, batch_size, *node.output_shape[1:]) elif isinstance(C, eyeC): C = C._replace(shape=(C.shape[1], C.shape[0], *C.shape[2:])) elif isinstance(C, OneHotC): @@ -506,65 +652,61 @@ def preprocess_C(C, node): return C, batch_size, output_dim, output_shape -def concretize(lb, ub, node, root, batch_size, output_dim, bound_lower=True, bound_upper=True, average_A=False): - - for i in range(len(root)): - if root[i].lA is None and root[i].uA is None: continue - if average_A and isinstance(root[i], BoundParams): - lA = root[i].lA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].lA.shape) if bound_lower else None - uA = root[i].uA.mean(node.batch_dim + 1, keepdim=True).expand(root[i].uA.shape) if bound_upper else None +def concretize(self, batch_size, output_dim, lb, ub=None, + bound_lower=True, bound_upper=True, + average_A=False, node_start=None): + roots = self.roots() + for i in range(len(roots)): + if roots[i].lA is None and roots[i].uA is None: continue + if average_A and isinstance(roots[i], BoundParams): + lA = roots[i].lA.mean( + node_start.batch_dim + 1, keepdim=True + ).expand(roots[i].lA.shape) if bound_lower else None + uA = roots[i].uA.mean( + node_start.batch_dim + 1, keepdim=True + ).expand(roots[i].uA.shape) if bound_upper else None else: - lA, uA = root[i].lA, root[i].uA - if not isinstance(root[i].lA, eyeC) and not isinstance(root[i].lA, Patches): - lA = root[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None - if not isinstance(root[i].uA, eyeC) and not isinstance(root[i].uA, Patches): - uA = root[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None - if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: - if isinstance(root[i], BoundParams): + lA, uA = roots[i].lA, roots[i].uA + if not isinstance(roots[i].lA, eyeC) and not isinstance(roots[i].lA, Patches): + lA = roots[i].lA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_lower else None + if not isinstance(roots[i].uA, eyeC) and not isinstance(roots[i].uA, Patches): + uA = roots[i].uA.reshape(output_dim, batch_size, -1).transpose(0, 1) if bound_upper else None + if hasattr(roots[i], 'perturbation') and roots[i].perturbation is not None: + if isinstance(roots[i], BoundParams): # add batch_size dim for weights node - lb = lb + root[i].perturbation.concretize( - root[i].center.unsqueeze(0), lA, - sign=-1, aux=root[i].aux) if bound_lower else None - ub = ub + root[i].perturbation.concretize( - root[i].center.unsqueeze(0), uA, - sign=+1, aux=root[i].aux) if bound_upper else None + lb = lb + roots[i].perturbation.concretize( + roots[i].center.unsqueeze(0), lA, + sign=-1, aux=roots[i].aux) if bound_lower else None + ub = ub + roots[i].perturbation.concretize( + roots[i].center.unsqueeze(0), uA, + sign=+1, aux=roots[i].aux) if bound_upper else None else: - lb = lb + root[i].perturbation.concretize( - root[i].center, lA, sign=-1, aux=root[i].aux) if bound_lower else None - ub = ub + root[i].perturbation.concretize( - root[i].center, uA, sign=+1, aux=root[i].aux) if bound_upper else None + lb = lb + roots[i].perturbation.concretize( + roots[i].center, lA, sign=-1, aux=roots[i].aux) if bound_lower else None + ub = ub + roots[i].perturbation.concretize( + roots[i].center, uA, sign=+1, aux=roots[i].aux) if bound_upper else None else: - fv = root[i].forward_value - if type(root[i]) == BoundInput: + fv = roots[i].forward_value + if type(roots[i]) == BoundInput: # Input node with a batch dimension batch_size_ = batch_size else: # Parameter node without a batch dimension batch_size_ = 1 - if bound_lower: - if isinstance(lA, eyeC): - lb = lb + fv.view(batch_size_, -1) - elif isinstance(lA, Patches): - lb = lb + lA.matmul(fv, input_shape=root[0].center.shape) - elif type(root[i]) == BoundInput: - lb = lb + lA.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1) + def _add_constant(A, b): + if isinstance(A, eyeC): + b = b + fv.view(batch_size_, -1) + elif isinstance(A, Patches): + b = b + A.matmul(fv, input_shape=roots[0].center.shape) + elif type(roots[i]) == BoundInput: + b = b + A.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1) else: - lb = lb + lA.matmul(fv.view(-1, 1)).squeeze(-1) - else: - lb = None - - if bound_upper: - if isinstance(uA, eyeC): - ub = ub + fv.view(batch_size_, -1) - elif isinstance(uA, Patches): - ub = ub + uA.matmul(fv, input_shape=root[0].center.shape) - elif type(root[i]) == BoundInput: - ub = ub + uA.matmul(fv.view(batch_size_, -1, 1)).squeeze(-1) - else: - ub = ub + uA.matmul(fv.view(-1, 1)).squeeze(-1) - else: - ub = None + b = b + A.matmul(fv.view(-1, 1)).squeeze(-1) + return b + + lb = _add_constant(lA, lb) if bound_lower else None + ub = _add_constant(uA, ub) if bound_upper else None return lb, ub @@ -581,7 +723,7 @@ def addA(A1, A2): raise NotImplementedError(f'Unsupported types for A1 ({type(A1)}) and A2 ({type(A2)}') -def add_bound(node, node_pre, lA, uA): +def add_bound(node, node_pre, lA=None, uA=None): """Propagate lA and uA to a preceding node.""" if lA is not None: if node_pre.lA is None: @@ -602,25 +744,6 @@ def add_bound(node, node_pre, lA, uA): node_pre.uA = addA(node_pre.uA, uA) -def get_beta_watch_list(intermediate_constr, all_nodes_before): - beta_watch_list = defaultdict(dict) - if intermediate_constr is not None: - # Intermediate layer betas are handled in two cases. - # First, if the beta split is before this node, we don't need to do anything special; - # it will done in BoundRelu. - # Second, if the beta split after this node, we need to modify the A matrix - # during bound propagation to reflect beta after this layer. - for k in intermediate_constr: - if k not in all_nodes_before: - # The second case needs special care: we add all such splits in a watch list. - # However, after first occurance of a layer in the watchlist, - # beta_watch_list will be deleted and the A matrix from split constraints - # has been added and will be propagated to later layers. - for kk, vv in intermediate_constr[k].items(): - beta_watch_list[kk][k] = vv - return beta_watch_list - - def add_constant_node(lb, ub, node): new_lb = node.get_bias(node.lA, node.forward_value) new_ub = node.get_bias(node.uA, node.forward_value) @@ -633,27 +756,27 @@ def add_constant_node(lb, ub, node): return lb, ub -def save_A_record(node, A_record, A_dict, root, needed_A_dict, lb, ub, unstable_idx): +def save_A_record(node, A_record, A_dict, roots, needed_A_dict, lb, ub, unstable_idx): root_A_record = {} - for i in range(len(root)): - if root[i].lA is None and root[i].uA is None: continue - if root[i].name in needed_A_dict: - if root[i].lA is not None: - if isinstance(root[i].lA, Patches): - _lA = root[i].lA + for i in range(len(roots)): + if roots[i].lA is None and roots[i].uA is None: continue + if roots[i].name in needed_A_dict: + if roots[i].lA is not None: + if isinstance(roots[i].lA, Patches): + _lA = roots[i].lA else: - _lA = root[i].lA.transpose(0, 1).detach() + _lA = roots[i].lA.transpose(0, 1).detach() else: _lA = None - if root[i].uA is not None: - if isinstance(root[i].uA, Patches): - _uA = root[i].uA + if roots[i].uA is not None: + if isinstance(roots[i].uA, Patches): + _uA = roots[i].uA else: - _uA = root[i].uA.transpose(0, 1).detach() + _uA = roots[i].uA.transpose(0, 1).detach() else: _uA = None - root_A_record.update({root[i].name: { + root_A_record.update({roots[i].name: { "lA": _lA, "uA": _uA, # When not used, lb or ub is tensor(0). They have been transposed above. @@ -678,8 +801,10 @@ def select_unstable_idx(ref_intermediate_lb, ref_intermediate_ub, unstable_locs, return indices_selected -def get_unstable_locations( - self, ref_intermediate_lb, ref_intermediate_ub, conv=False, channel_only=False): +def get_unstable_locations(self: 'BoundedModule', ref_intermediate_lb, + ref_intermediate_ub, conv=False, channel_only=False): + # FIXME (2023): This function should be a member class of the Bound object, since the + # definition of unstable neurons depends on the activation function. max_crown_size = self.bound_opts.get('max_crown_size', int(1e9)) # For conv layer we only check the case where all neurons are active/inactive. unstable_masks = torch.logical_and(ref_intermediate_lb < 0, ref_intermediate_ub > 0) @@ -694,9 +819,9 @@ def get_unstable_locations( else: if not conv and unstable_masks.ndim > 2: # Flatten the conv layer shape. - unstable_masks = unstable_masks.view(unstable_masks.size(0), -1) - ref_intermediate_lb = ref_intermediate_lb.view(ref_intermediate_lb.size(0), -1) - ref_intermediate_ub = ref_intermediate_ub.view(ref_intermediate_ub.size(0), -1) + unstable_masks = unstable_masks.reshape(unstable_masks.size(0), -1) + ref_intermediate_lb = ref_intermediate_lb.reshape(ref_intermediate_lb.size(0), -1) + ref_intermediate_ub = ref_intermediate_ub.reshape(ref_intermediate_ub.size(0), -1) unstable_locs = unstable_masks.sum(dim=0).bool() if conv: # Now converting it to indices for these unstable nuerons. @@ -719,45 +844,85 @@ def get_unstable_locations( def get_alpha_crown_start_nodes( - self, node, c=None, share_slopes=False, final_node_name=None): + self: 'BoundedModule', + node, + c=None, + share_alphas=False, + final_node_name=None, + backward_from_node: Bound = None, + ): + """ + Given a layer "node", return a list of following nodes after this node whose bounds + will propagate through this node. Each element in the list is a tuple with 3 elements: + (following_node_name, following_node_shape, unstable_idx) + """ # When use_full_conv_alpha is True, conv layers do not share alpha. sparse_intermediate_bounds = self.bound_opts.get('sparse_intermediate_bounds', False) use_full_conv_alpha_thresh = self.bound_opts.get('use_full_conv_alpha_thresh', 512) start_nodes = [] - for nj in self.backward_from[node.name]: # Pre-activation layers. + # In most cases, backward_from_node == node + # Only if output constraints are used, will they differ: the node that should be + # bounded (node) needs alphas for *all* layers, not just those behind it. + # In this case, backward_from_node will be the input node + if backward_from_node != node: + assert len(self.bound_opts['optimize_bound_args']['apply_output_constraints_to']) > 0 + + for nj in self.backward_from[backward_from_node.name]: # Pre-activation layers. unstable_idx = None - use_sparse_conv = None + use_sparse_conv = None # Whether a sparse-spec alpha is used for a conv output node. None for non-conv output node. use_full_conv_alpha = self.bound_opts.get('use_full_conv_alpha', False) - if (sparse_intermediate_bounds and isinstance(node, BoundRelu) - and nj.name != final_node_name and not share_slopes): + + # Find the indices of unstable neuron, used for create sparse-feature alpha. + if (sparse_intermediate_bounds + and isinstance(node, BoundOptimizableActivation) + and nj.name != final_node_name and not share_alphas): # Create sparse optimization variables for intermediate neurons. - if ((isinstance(nj, BoundLinear) or isinstance(nj, BoundMatMul)) - and int(os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0): - # unstable_idx has shape [neuron_size_of_nj]. Batch dimension is reduced. - unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper) - elif isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches': - if nj.name in node.patch_size: - # unstable_idx has shape [channel_size_of_nj]. Batch and spatial dimensions are reduced. - unstable_idx, _ = self.get_unstable_locations( - nj.lower, nj.upper, channel_only=not use_full_conv_alpha, conv=True) - use_sparse_conv = False # alpha is shared among channels. Sparse-spec alpha in hw dimension not used. - if use_full_conv_alpha and unstable_idx[0].size(0) > use_full_conv_alpha_thresh: - # Too many unstable neurons. Using shared alpha per channel. - unstable_idx, _ = self.get_unstable_locations( - nj.lower, nj.upper, channel_only=True, conv=True) - use_full_conv_alpha = False - else: - # matrix mode for conv layers. - # unstable_idx has shape [c_out * h_out * w_out]. Batch dimension is reduced. + # These are called "sparse-spec" alpha because we only create alpha only for + # the intermediate of final output nodes whose bounds are needed. + # "sparse-spec" alpha makes sense only for piece-wise linear functions. + # For other intermediate nodes, there is no "unstable" or "stable" neuron. + # FIXME: whether an layer has unstable/stable neurons should be in Bound obj. + # FIXME: get_unstable_locations should be a member class of ReLU. + if len(nj.output_name) == 1 and isinstance(self[nj.output_name[0]], (BoundRelu, BoundSignMerge, BoundMaxPool)): + if ((isinstance(nj, (BoundLinear, BoundMatMul))) + and int(os.environ.get('AUTOLIRPA_USE_FULL_C', 0)) == 0): + # unstable_idx has shape [neuron_size_of_nj]. Batch dimension is reduced. unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper) - use_sparse_conv = True # alpha is not shared among channels, and is sparse in spec dimension. + elif isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches': + if nj.name in node.patch_size: + # unstable_idx has shape [channel_size_of_nj]. Batch and spatial dimensions are reduced. + unstable_idx, _ = self.get_unstable_locations( + nj.lower, nj.upper, channel_only=not use_full_conv_alpha, conv=True) + use_sparse_conv = False # alpha is shared among channels. Sparse-spec alpha in hw dimension not used. + if use_full_conv_alpha and unstable_idx[0].size(0) > use_full_conv_alpha_thresh: + # Too many unstable neurons. Using shared alpha per channel. + unstable_idx, _ = self.get_unstable_locations( + nj.lower, nj.upper, channel_only=True, conv=True) + use_full_conv_alpha = False + else: + # Matrix mode for conv layers. Although the bound propagation started with patches mode, + # when A matrix is propagated to this layer, it might become a dense matrix since patches + # can be come very large after many layers. In this case, + # unstable_idx has shape [c_out * h_out * w_out]. Batch dimension is reduced. + unstable_idx, _ = self.get_unstable_locations(nj.lower, nj.upper) + use_sparse_conv = True # alpha is not shared among channels, and is sparse in spec dimension. + else: + # FIXME: we should not check for fixed names here. Need to enable patches mode more generally. + if isinstance(nj, (BoundConv, BoundAdd, BoundSub, BoundBatchNormalization)) and nj.mode == 'patches': + use_sparse_conv = False # Sparse-spec alpha can never be used, because it is not a ReLU activation. + if nj.name == final_node_name: + # Final layer, always the number of specs as the shape. size_final = self[final_node_name].output_shape[-1] if c is None else c.size(1) - start_nodes.append((final_node_name, size_final, None)) + # The 4-th element indicates that this start node is the final node, + # which may be utilized by operators that do not know the name of + # the final node. + start_nodes.append((final_node_name, size_final, None, True)) continue - if share_slopes: - # all intermediate neurons from the same layer share the same set of slopes. + + if share_alphas: + # all intermediate neurons from the same layer share the same set of alphas. output_shape = 1 elif isinstance(node, BoundOptimizableActivation) and node.patch_size and nj.name in node.patch_size: # Patches mode. Use output channel size as the spec size. This still shares some alpha, but better than no sharing. @@ -770,8 +935,75 @@ def get_alpha_crown_start_nodes( output_shape = node.patch_size[nj.name][0] assert not sparse_intermediate_bounds or use_sparse_conv is False # Double check our assumption holds. If this fails, then we created wrong shapes for alpha. else: - # Output is linear layer, or patch converted to matrix. + # Output is linear layer (use_sparse_conv = None), or patch converted to matrix (use_sparse_conv = True). assert not sparse_intermediate_bounds or use_sparse_conv is not False # Double check our assumption holds. If this fails, then we created wrong shapes for alpha. output_shape = nj.lower.shape[1:] # FIXME: for non-relu activations it's still expecting a prod. - start_nodes.append((nj.name, output_shape, unstable_idx)) + start_nodes.append((nj.name, output_shape, unstable_idx, False)) return start_nodes + + +def merge_A(batch_A, ret_A): + for key0 in batch_A: + if key0 not in ret_A: ret_A[key0] = {} + for key1 in batch_A[key0]: + value = batch_A[key0][key1] + if key1 not in ret_A[key0]: + # create: + ret_A[key0].update({ + key1: { + "lA": value["lA"], + "uA": value["uA"], + "lbias": value["lbias"], + "ubias": value["ubias"], + "unstable_idx": value["unstable_idx"] + } + }) + elif key0 == node.name: + # merge: + # the batch splitting only happens for current node, i.e., + # for other nodes the returned lA should be the same across different batches + # so no need to repeatly merge them + exist = ret_A[key0][key1] + + if exist["unstable_idx"] is not None: + if isinstance(exist["unstable_idx"], torch.Tensor): + merged_unstable = torch.cat([ + exist["unstable_idx"], + value['unstable_idx']], dim=0) + elif isinstance(exist["unstable_idx"], tuple): + if exist["unstable_idx"]: + merged_unstable = tuple([ + torch.cat([exist["unstable_idx"][idx], + value['unstable_idx'][idx]], dim=0) + for idx in range(len(exist['unstable_idx']))] + ) + else: + merged_unstable = None + else: + raise NotImplementedError( + f'Unsupported type {type(exist["unstable_idx"])}') + else: + merged_unstable = None + merge_dict = {"unstable_idx": merged_unstable} + for name in ["lA", "uA"]: + if exist[name] is not None: + if isinstance(exist[name], torch.Tensor): + # for matrix the spec dim is 1 + merge_dict[name] = torch.cat([exist[name], value[name]], dim=1) + else: + assert isinstance(exist[name], Patches) + # for patches the spec dim`is 0 + merge_dict[name] = exist[name].create_similar( + torch.cat([exist[name].patches, value[name].patches], dim=0), + unstable_idx=merged_unstable + ) + else: + merge_dict[name] = None + for name in ["lbias", "ubias"]: + if exist[name] is not None: + # for bias the spec dim in 1 + merge_dict[name] = torch.cat([exist[name], value[name]], dim=1) + else: + merge_dict[name] = None + ret_A[key0][key1] = merge_dict + return ret_A diff --git a/auto_LiRPA/beta_crown.py b/auto_LiRPA/beta_crown.py index 6be6dc0..ad4295c 100644 --- a/auto_LiRPA/beta_crown.py +++ b/auto_LiRPA/beta_crown.py @@ -1,48 +1,238 @@ +from collections import OrderedDict import torch +from torch import Tensor +from .patches import Patches, inplace_unfold +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule -def beta_bias(self): - batch_size = len(self.relus[-1].split_beta) - batch = int(batch_size/2) - bias = torch.zeros((batch_size, 1), device=self.device) - for m in self.relus: - if not m.used or not m.perturbed: + +class SparseBeta: + def __init__(self, shape, bias=False, betas=None, device='cpu'): + self.device = device + self.val = torch.zeros(shape) + self.loc = torch.zeros(shape, dtype=torch.long, device=device) + self.sign = torch.zeros(shape, device=device) + self.bias = torch.zeros(shape) if bias else None + if betas: + for bi in range(len(betas)): + if betas[bi] is not None: + self.val[bi, :len(betas[bi])] = betas[bi] + self.val = self.val.detach().to( + device, non_blocking=True).requires_grad_() + + def apply_splits(self, history, key): + for bi in range(len(history)): + # Add history splits. (layer, neuron) is the current decision. + split_locs, split_coeffs = history[bi][key][:2] + split_len = len(split_locs) + if split_len > 0: + self.sign[bi, :split_len] = torch.as_tensor( + split_coeffs, device=self.device) + self.loc[bi, :split_len] = torch.as_tensor( + split_locs, device=self.device) + if self.bias is not None: + split_bias = history[bi][key][2] + self.bias[bi, :split_len] = torch.as_tensor( + split_bias, device=self.device) + self.loc = self.loc.to(device=self.device, non_blocking=True) + self.sign = self.sign.to(device=self.device, non_blocking=True) + if self.bias is not None: + self.bias = self.bias.to(device=self.device, non_blocking=True) + + +def get_split_nodes(self: 'BoundedModule', input_split=False): + self.split_nodes = [] + self.split_activations = {} + splittable_activations = self.get_splittable_activations() + self._set_used_nodes(self[self.final_name]) + for layer in self.layers_requiring_bounds: + split_activations_ = [] + for activation_name in layer.output_name: + activation = self[activation_name] + if activation in splittable_activations: + split_activations_.append( + (activation, activation.inputs.index(layer))) + if split_activations_: + self.split_nodes.append(layer) + self.split_activations[layer.name] = split_activations_ + if input_split: + root = self[self.root_names[0]] + if root not in self.split_nodes: + self.split_nodes.append(root) + self.split_activations[root.name] = [] + return self.split_nodes, self.split_activations + + +def set_beta(self: 'BoundedModule', enable_opt_interm_bounds, parameters, + lr_beta, lr_cut_beta, cutter, dense_coeffs_mask): + """ + Set betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases + and best_biases. + """ + coeffs = None + betas = [] + best_betas = OrderedDict() + + # TODO compute only once + self.nodes_with_beta = [] + for node in self.split_nodes: + if not hasattr(node, 'sparse_betas'): continue - if m.split_beta_used: - bias[:batch] = bias[:batch] + m.split_bias*m.split_beta[:batch]*m.split_c[:batch] - bias[batch:] = bias[batch:] + m.split_bias*m.split_beta[batch:]*m.split_c[batch:] - if m.history_beta_used: - bias = bias + (m.new_history_bias*m.new_history_beta*m.new_history_c).sum(1, keepdim=True) - # No single node split here, because single node splits do not have bias. - return bias + self.nodes_with_beta.append(node) + if enable_opt_interm_bounds: + for sparse_beta in node.sparse_betas.values(): + if sparse_beta is not None: + betas.append(sparse_beta.val) + best_betas[node.name] = { + beta_m: sparse_beta.val.detach().clone() + for beta_m, sparse_beta in node.sparse_betas.items() + } + else: + betas.append(node.sparse_betas[0].val) + best_betas[node.name] = node.sparse_betas[0].val.detach().clone() + + # Beta has shape (batch, max_splits_per_layer) + parameters.append({'params': betas.copy(), 'lr': lr_beta, 'batch_dim': 0}) + + if self.cut_used: + self.set_beta_cuts(parameters, lr_cut_beta, betas, best_betas, cutter) + + return betas, best_betas, coeffs, dense_coeffs_mask + + +def set_beta_cuts(self: 'BoundedModule', parameters, lr_cut_beta, betas, + best_betas, cutter): + # also need to optimize cut betas + parameters.append({'params': self.cut_beta_params, + 'lr': lr_cut_beta, 'batch_dim': 0}) + betas += self.cut_beta_params + best_betas['cut'] = [beta.detach().clone() for beta in self.cut_beta_params] + if getattr(cutter, 'opt', False): + parameters.append(cutter.get_parameters()) + + +def reset_beta(self: 'BoundedModule', node, shape, betas, bias=False, + start_nodes=None): + # Create only the non-zero beta. For each layer, it is padded to maximal length. + # We create tensors on CPU first, and they will be transferred to GPU after initialized. + if self.bound_opts.get('enable_opt_interm_bounds', False): + node.sparse_betas = { + key: SparseBeta( + shape, + betas=[(betas[j][i] if betas[j] is not None else None) + for j in range(len(betas))], + device=self.device, bias=bias, + ) for i, key in enumerate(start_nodes) + } + else: + node.sparse_betas = [SparseBeta( + shape, betas=betas, device=self.device, bias=bias)] + + +def beta_crown_backward_bound(self: 'BoundedModule', node, lA, uA, start_node=None): + """Update A and bias with Beta-CROWN. + + Must be explicitly called at the end of "bound_backward". + """ + # Regular Beta CROWN with single neuron split + # Each split constraint only has single neuron (e.g., second ReLU neuron > 0). + A = lA if lA is not None else uA + lbias = ubias = 0 + + def _bias_unsupported(): + raise NotImplementedError('Bias for beta not supported in this case.') + + if type(A) is Patches: + if not self.bound_opts.get('enable_opt_interm_bounds', False): + raise NotImplementedError('Sparse beta not supported in the patches mode') + if node.sparse_betas[start_node.name].bias is not None: + _bias_unsupported() + # expand sparse_beta to full beta + beta_values = (node.sparse_betas[start_node.name].val + * node.sparse_betas[start_node.name].sign) + beta_indices = node.sparse_betas[start_node.name].loc + node.masked_beta = torch.zeros(2, *node.shape).reshape(2, -1).to(A.patches.dtype) + node.non_deter_scatter_add( + node.masked_beta, dim=1, index=beta_indices, + src=beta_values.to(node.masked_beta.dtype)) + node.masked_beta = node.masked_beta.reshape(2, *node.shape) + # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W) + A_patches = A.patches + masked_beta_unfolded = inplace_unfold( + node.masked_beta, kernel_size=A_patches.shape[-2:], + padding=A.padding, stride=A.stride, + inserted_zeros=A.inserted_zeros, output_padding=A.output_padding) + if A.unstable_idx is not None: + masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4, 5) + # After selection, the shape is (unstable_size, batch, in_c, H, W). + masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]] + else: + # Add the spec (out_c) dimension. + masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0) + if node.alpha_beta_update_mask is not None: + masked_beta_unfolded = masked_beta_unfolded[node.alpha_beta_update_mask] + if uA is not None: + uA = uA.create_similar(uA.patches + masked_beta_unfolded) + if lA is not None: + lA = lA.create_similar(lA.patches - masked_beta_unfolded) + elif type(A) is Tensor: + if self.bound_opts.get('enable_opt_interm_bounds', False): + if node.sparse_betas[start_node.name].bias is not None: + _bias_unsupported() + # For matrix mode, beta is sparse. + beta_values = ( + node.sparse_betas[start_node.name].val + * node.sparse_betas[start_node.name].sign + ).expand(A.size(0), -1, -1) + # node.single_beta_loc has shape [batch, max_single_split]. + # Need to expand at the specs dimension. + beta_indices = (node.sparse_betas[start_node.name].loc + .unsqueeze(0).expand(A.size(0), -1, -1)) + beta_bias = node.sparse_betas[start_node.name].bias + else: + # For matrix mode, beta is sparse. + beta_values = ( + node.sparse_betas[0].val * node.sparse_betas[0].sign + ).expand(A.size(0), -1, -1) + # self.single_beta_loc has shape [batch, max_single_split]. + # Need to expand at the specs dimension. + beta_indices = node.sparse_betas[0].loc.unsqueeze(0).expand(A.size(0), -1, -1) + beta_bias = node.sparse_betas[0].bias + # For conv layer, the last dimension is flattened in indices. + beta_values = beta_values.to(A.dtype) + if beta_bias is not None: + beta_bias = beta_bias.expand(A.size(0), -1, -1) + if node.alpha_beta_update_mask is not None: + beta_indices = beta_indices[:, node.alpha_beta_update_mask] + beta_values = beta_values[:, node.alpha_beta_update_mask] + if beta_bias is not None: + beta_bias = beta_bias[:, node.alpha_beta_update_mask] + if uA is not None: + uA = node.non_deter_scatter_add( + uA.reshape(uA.size(0), uA.size(1), -1), dim=2, + index=beta_indices, src=beta_values).view(uA.size()) + if lA is not None: + lA = node.non_deter_scatter_add( + lA.reshape(lA.size(0), lA.size(1), -1), dim=2, + index=beta_indices, src=beta_values.neg()).view(lA.size()) + if beta_bias is not None: + bias = (beta_values * beta_bias).sum(dim=-1) + lbias = bias + ubias = -bias + else: + raise RuntimeError(f"Unknown type {type(A)} for A") + + return lA, uA, lbias, ubias -def print_optimized_beta(self, relus, intermediate_beta_enabled=False): +def print_optimized_beta(acts): masked_betas = [] - for model in relus: + for model in acts: masked_betas.append(model.masked_beta) if model.history_beta_used: - print(f"{model.name} history beta", model.new_history_beta.squeeze()) + print(f'{model.name} history beta', model.new_history_beta.squeeze()) if model.split_beta_used: - print(f"{model.name} split beta:", model.split_beta.view(-1)) - print(f"{model.name} bias:", model.split_bias) - - -def save_best_intermediate_betas(self, relus, idx): - for layer in relus: - # The history split and current split is handled seperatedly. - if layer.history_beta_used: - # Each key in history_intermediate_betas for this layer is a dictionary, with all other pre-relu layers' names. - for k, v in layer.history_intermediate_betas.items(): - # This is a tensor with shape (batch, *intermediate_layer_shape, number_of_beta) - self.best_intermediate_betas[layer.name]['history'][k]["lb"][idx] = v["lb"][idx] - self.best_intermediate_betas[layer.name]['history'][k]["ub"][idx] = v["ub"][idx] - if layer.split_beta_used: - for k, v in layer.split_intermediate_betas.items(): - # This is a tensor with shape (batch, *intermediate_layer_shape, 1) - self.best_intermediate_betas[layer.name]['split'][k]["lb"][idx] = v["lb"][idx] - self.best_intermediate_betas[layer.name]['split'][k]["ub"][idx] = v["ub"][idx] - if layer.single_beta_used: - for k, v in layer.single_intermediate_betas.items(): - self.best_intermediate_betas[layer.name]['single'][k]["lb"][idx] = v["lb"][idx] - self.best_intermediate_betas[layer.name]['single'][k]["ub"][idx] = v["ub"][idx] \ No newline at end of file + print(f'{model.name} split beta:', model.split_beta.view(-1)) + print(f'{model.name} bias:', model.split_bias) diff --git a/auto_LiRPA/bound_general.py b/auto_LiRPA/bound_general.py index 3348d5b..43a7455 100644 --- a/auto_LiRPA/bound_general.py +++ b/auto_LiRPA/bound_general.py @@ -1,4 +1,5 @@ import copy +from typing import List import numpy as np import warnings from collections import OrderedDict, deque @@ -13,11 +14,12 @@ from .perturbations import * from .utils import * from .patches import Patches - +from .optimized_bounds import default_optimize_bound_args warnings.simplefilter('once') + class BoundedModule(nn.Module): """Bounded module with support for automatically computing bounds. @@ -70,13 +72,20 @@ def __init__(self, model, global_input, bound_opts=None, 'forward_max_dim': int(1e9), # Do not share alpha for conv layers. 'use_full_conv_alpha': True, + 'disabled_optimization': [], # Threshold for number of unstable neurons for each layer to disable # use_full_conv_alpha. 'use_full_conv_alpha_thresh': 512, 'verbosity': 1 if verbose else 0, + 'optimize_graph': {'optimizer': None}, } default_bound_opts.update(bound_opts) self.bound_opts = default_bound_opts + optimize_bound_args = default_optimize_bound_args + optimize_bound_args.update( + self.bound_opts.get('optimize_bound_args', {})) + self.bound_opts.update({'optimize_bound_args': optimize_bound_args}) + self.verbose = verbose self.custom_ops = custom_ops if custom_ops is not None else {} if device == 'auto': @@ -93,6 +102,7 @@ def __init__(self, model, global_input, bound_opts=None, self.optimizable_activations = [] self.relus = [] # save relu layers for convenience + self.layers_with_constraint = [] state_dict_copy = copy.deepcopy(model.state_dict()) object.__setattr__(self, 'ori_state_dict', state_dict_copy) @@ -102,70 +112,8 @@ def __init__(self, model, global_input, bound_opts=None, self.bound_opts.update({'final_shape': self.final_shape}) self._convert(model, global_input) self._mark_perturbed_nodes() - - # set the default values here - optimize_bound_args = { - 'enable_alpha_crown': True, # Enable optimization of alpha. - 'enable_beta_crown': False, # Enable beta split constraint. - 'iteration': 20, # Number of alpha/beta optimization iterations. - # Share some alpha variables to save memory at the cost of slightly - # looser bounds. - 'use_shared_alpha': False, - # Optimize coeffs during intermediate_refinement. - 'opt_coeffs': False, - # Optimize constraint bias during intermediate_refinement. - 'opt_bias': False, - # Optimizer used for alpha and beta optimization. - 'optimizer': 'adam', - # Save best results of alpha/beta/bounds during optimization. - 'keep_best': True, - # Only optimize bounds of last layer during alpha/beta CROWN. - 'fix_intermediate_layer_bounds': True, - # Learning rate for the optimizable parameter alpha in alpha-CROWN. - 'lr_alpha': 0.5, - # Learning rate for the optimizable parameter beta in beta-CROWN. - 'lr_beta': 0.05, - 'lr_cut_beta': 5e-3, # Learning rate for optimizing cut betas. - # Initial alpha variables by calling CROWN once. - 'init_alpha': True, - # Only split single nodes in branch and bound. - 'single_node_split': True, - # Learning rate for intermediate layer beta for refinement. - 'lr_intermediate_beta': 0.1, - 'lr_coeffs': 0.01, # Learning rate for coeffs for refinement - # Optimize constraint bias in compute bounds. - 'intermediate_beta': False, - # Layers to be refined, separated by commas. - # -1 means preactivation before last relu. - 'intermediate_refinement_layers': [-1], - # When batch size is not 1, this reduction function is applied to - # reduce the bounds into a scalar. - 'loss_reduction_func': reduction_sum, - # Criteria function of early stop. - 'stop_criterion_func': lambda x: False, - # Learning rate decay factor during bounds optimization. - 'lr_decay': 0.98, - # Number of iterations that we will start considering early stop - # if tracking no improvement. - 'early_stop_patience': 10, - # Start to save optimized best bounds - # when current_iteration > int(iteration*start_save_best) - 'start_save_best': 0.5, - # Use double fp (float64) at the last iteration in alpha/beta CROWN. - 'use_float64_in_last_iteration': False, - # Prune verified domain within iteration. - 'pruning_in_iteration': False, - # Percentage of the minimum domains that can apply pruning. - 'pruning_in_iteration_threshold': 0.2, - # For specification that will output multiple bounds for one - # property, we use this function to prune them. - 'multi_spec_keep_func': lambda x: True - } - - # change by bound_opts - optimize_bound_args.update( - self.bound_opts.get('optimize_bound_args', {})) - self.bound_opts.update({'optimize_bound_args': optimize_bound_args}) + self._optimize_graph() + self._expand_jacobian() self.next_split_hint = [] # Split hints, used in beta optimization. # Beta values for all intermediate bounds. @@ -177,48 +125,61 @@ def __init__(self, model, global_input, bound_opts=None, self.cut_used = False # a placeholder for cut timestamp, which would be a non-positive int self.cut_timestamp = -1 - - # List of operators. When we are computing intermediate bounds for these - # ops, we simply use IBP to propagate bounds from its input nodes, - # instead of CROWN. - # NOTE Currently only operators with a single input can be supported. - self.ibp_intermediate = [ - BoundRelu, - BoundNeg, - BoundTranspose, - BoundSin, - BoundCos, - BoundTan, - BoundAtan] - # a placeholder to save the latest samplewise mask for # pruning-in-iteration optimization self.last_update_preserve_mask = None + def nodes(self) -> List[Bound]: + return self._modules.values() + + def get_enabled_opt_act(self): + # Optimizable activations that are actually used and perturbed + return [ + n for n in self.optimizable_activations + if n.used and n.perturbed and not getattr(n, 'is_linear_op', False) + ] + + def get_optimizable_activations(self): + for node in self.nodes(): + if (isinstance(node, BoundOptimizableActivation) + and node.optimizable + and len(getattr(node, 'requires_input_bounds', [])) > 0 + and node not in self.optimizable_activations): + disabled = False + for item in self.bound_opts.get('disable_optimization', []): + if item.lower() in str(type(node)).lower(): + disabled = True + if disabled: + logging.info('Disabled optimization for %s', node) + continue + if node not in self.optimizable_activations: + self.optimizable_activations.append(node) + if isinstance(node, BoundRelu) and node not in self.relus: + self.relus.append(node) + def get_perturbed_optimizable_activations(self): return [n for n in self.optimizable_activations if n.perturbed] def get_splittable_activations(self): - """ - Activation functions that can be split during branch and bound. - """ - # TODO: Add other activation functions in a sysmatic manner. - # Do not hard code relus or other functions. - # A node must be perturbed, otherwise it is a constant and no need - # to perturb. - return [n for n in self.relus if n.perturbed] + """Activation functions that can be split during branch and bound.""" + return [n for n in self.nodes() if n.perturbed and n.splittable] def get_layers_requiring_bounds(self): - """ - Layer names whose intermediate layer bounds are required. - """ + """Layer names whose intermediate layer bounds are required.""" intermediate_layers = [] - for node in self.modules(): + for node in self.nodes(): + if not node.used or not node.perturbed: + continue for i in getattr(node, 'requires_input_bounds', []): input_node = node.inputs[i] - if input_node.used and input_node.perturbed: + if (input_node not in intermediate_layers + and input_node.perturbed): + # If not perturbed, it may not have the batch dimension. + # So we do not include it, and it is unnecessary. intermediate_layers.append(input_node) - + if (node.name in self.layers_with_constraint + and node not in intermediate_layers): + intermediate_layers.append(node) return intermediate_layers def check_incompatible_nodes(self, model): @@ -314,8 +275,7 @@ def register_parameter(self, name, param): if '_parameters' not in self.__dict__: raise AttributeError( 'cannot assign parameter before Module.__init__() call') - - elif not isinstance(name, torch._six.string_classes): + elif not isinstance(name, str): raise TypeError('parameter name should be a string. ' f'Got {torch.typename(name)}') elif name == '': @@ -347,11 +307,12 @@ def load_state_dict(self, state_dict, strict=False): new_dict[self.node_name_map[k]] = v return super().load_state_dict(new_dict, strict=strict) - def _named_members(self, get_members_fn, prefix='', recurse=True): + def _named_members(self, get_members_fn, prefix='', recurse=True, **kwargs): # pylint: disable=unused-argument r"""Helper method for yielding various names + members of modules.""" memo = set() modules = self.named_modules(prefix=prefix) if recurse else [ (prefix, self)] + # TODO: support the "remove_duplicate" argument, new in pytorch 2.0. for module_prefix, module in modules: members = get_members_fn(module) for k, v in members: @@ -366,22 +327,21 @@ def _named_members(self, get_members_fn, prefix='', recurse=True): def train(self, mode=True): super().train(mode) - for node in self._modules.values(): + for node in self.nodes(): node.train(mode=mode) def eval(self): super().eval() - for node in self._modules.values(): + for node in self.nodes(): node.eval() def to(self, *args, **kwargs): # Moves and/or casts some attributes except pytorch will do by default. - for node in self._modules.values(): + for node in self.nodes(): for attr in ['lower', 'upper', 'forward_value', 'd', 'lA',]: if hasattr(node, attr): this_attr = getattr(node, attr) if isinstance(this_attr, torch.Tensor): - # print(node, attr) this_attr = this_attr.to(*args, **kwargs) setattr(node, attr, this_attr) @@ -394,7 +354,13 @@ def to(self, *args, **kwargs): return super().to(*args, **kwargs) def __getitem__(self, name): - return self._modules[name] + module = self._modules[name] + # We never create modules that are None, the assert fixes type hints + assert module is not None + return module + + def roots(self): + return [self[name] for name in self.root_names] def final_node(self): return self[self.final_name] @@ -446,31 +412,15 @@ def forward(self, *x, final_node_name=None, clear_forward_only=False, output: The output of the model, or if `final_node_name` is not `None`, return the value on the corresponding node instead. """ - self._set_input(*x, clear_forward_only=clear_forward_only, + self.set_input(*x, clear_forward_only=clear_forward_only, reset_perturbed_nodes=reset_perturbed_nodes) if final_node_name: return self.get_forward_value(self[final_node_name]) else: - out = deque([self.get_forward_value(self[n]) - for n in self.output_name]) - - def _fill_template(template): - if template is None: - return out.popleft() - elif isinstance(template, (list, tuple)): - res = [] - for t in template: - res.append(_fill_template(t)) - return tuple(res) if isinstance(template, tuple) else res - elif isinstance(template, dict): - res = {} - for key in template: - res[key] = _fill_template(template[key]) - return res - else: - raise NotImplementedError - - return _fill_template(self.output_template) + return fill_template( + deque([self.get_forward_value(self[n]) + for n in self.output_name]), + self.output_template) def _mark_perturbed_nodes(self): """Mark the graph nodes and determine which nodes need perturbation.""" @@ -491,12 +441,7 @@ def _mark_perturbed_nodes(self): # BoundParams object, depending on ptb. for name_next in node.output_name: node_next = self[name_next] - if isinstance(node, BoundShape): - # Some nodes like Shape, even connected, - # do not really propagate bounds. - # TODO: make this a property of node? - pass - else: + if not node_next.never_perturbed: # The next node is perturbed if it is already perturbed, # or this node is perturbed. node_next.perturbed = node_next.perturbed or node.perturbed @@ -505,31 +450,17 @@ def _mark_perturbed_nodes(self): # now put it in queue. if degree_in[name_next] == 0: queue.append(node_next) - # Check whether weights are perturbed - if isinstance(node, (BoundLinear, BoundConv, BoundBatchNormalization)): - weight_perturbation = False - for n in node.inputs[1:]: - if hasattr(n, 'perturbation'): - if n.perturbation is not None: - weight_perturbation = True - if weight_perturbation: - node.requires_input_bounds = list(range(len(node.inputs))) - else: - if isinstance(node, BoundMatMul): - node.requires_input_bounds = [1] - else: - node.requires_input_bounds = [] + node.update_requires_input_bounds() - self.splittable_activations =( - self.get_splittable_activations()) + self.get_optimizable_activations() + self.splittable_activations = self.get_splittable_activations() self.perturbed_optimizable_activations = ( - self.get_perturbed_optimizable_activations()) + self.get_perturbed_optimizable_activations()) return - def _clear_and_set_new( - self, intermediate_layer_bounds, clear_forward_only=False, - reset_perturbed_nodes=True): - for l in self._modules.values(): + def _clear_and_set_new(self, interm_bounds, clear_forward_only=False, + reset_perturbed_nodes=True): + for l in self.nodes(): if hasattr(l, 'linear'): if isinstance(l.linear, tuple): for item in l.linear: @@ -543,22 +474,40 @@ def _clear_and_set_new( if hasattr(l, 'forward_value'): delattr(l, 'forward_value') else: - for attr in [ - 'lower', 'upper', 'interval', 'forward_value', 'd', - 'lA', 'lower_d']: + for attr in ['lower', 'upper', 'interval', 'forward_value', 'd', + 'lA', 'lower_d']: if hasattr(l, attr): + # If we use output constraints to tighten bounds, the bound + # computation of every layer will begin at the output layer. + # Thus, it will require to backpropagate through ReLUs *behind* + # the currently bounded one. For those ReLUs, the relaxation + # must depend on the lower and upper bounds of the previous + # iteration. Usually, those would be deleted here, so we must + # save them. + # Keeping them as the .lower and .upper parameters is not + # possible, as the rest of the framework assumes that layers + # that were not bounded in this iteration do not have those + # parameters. + apply_output_constraints_to = ( + self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] + ) + if ( + apply_output_constraints_to is not None and + len(apply_output_constraints_to) > 0 and + attr in ['lower', 'upper'] and + isinstance(getattr(l, attr), torch.Tensor) + ): + setattr(l, f'previous_iteration_{attr}', getattr(l, attr).detach()) delattr(l, attr) - for attr in [ - 'zero_backward_coeffs_l', 'zero_backward_coeffs_u', - 'zero_lA_mtx', 'zero_uA_mtx']: + for attr in ['zero_backward_coeffs_l', 'zero_backward_coeffs_u', + 'zero_lA_mtx', 'zero_uA_mtx']: setattr(l, attr, False) # Given an interval here to make IBP/CROWN start from this node - if (intermediate_layer_bounds is not None - and l.name in intermediate_layer_bounds.keys()): - l.interval = tuple(intermediate_layer_bounds[l.name][:2]) - l.lower = intermediate_layer_bounds[l.name][0] - l.upper = intermediate_layer_bounds[l.name][1] + if interm_bounds is not None and l.name in interm_bounds.keys(): + l.interval = tuple(interm_bounds[l.name][:2]) + l.lower = interm_bounds[l.name][0] + l.upper = interm_bounds[l.name][1] if l.lower is not None: l.lower = l.lower.detach().requires_grad_(False) if l.upper is not None: @@ -571,11 +520,10 @@ def _clear_and_set_new( # Clear operator-specific attributes l.clear() - def _set_input( - self, *x, intermediate_layer_bounds=None, - clear_forward_only=False, reset_perturbed_nodes=True): + def set_input(self, *x, interm_bounds=None, + clear_forward_only=False, reset_perturbed_nodes=True): self._clear_and_set_new( - intermediate_layer_bounds=intermediate_layer_bounds, + interm_bounds=interm_bounds, clear_forward_only=clear_forward_only, reset_perturbed_nodes=reset_perturbed_nodes) inputs_unpacked = unpack_inputs(x) @@ -594,7 +542,6 @@ def _set_input( def _get_node_input(self, nodesOP, nodesIn, node): ret = [] - ori_names = [] for i in range(len(node.inputs)): for op in nodesOP: if op.name == node.inputs[i]: @@ -605,7 +552,6 @@ def _get_node_input(self, nodesOP, nodesIn, node): for io in nodesIn: if io.name == node.inputs[i]: ret.append(io.bound_node) - ori_names.append(io.ori_name) break if len(ret) <= i: raise ValueError(f'cannot find inputs of node: {node.name}') @@ -669,9 +615,8 @@ def _convert_nodes(self, model, global_input): bound_class = BoundParams if isinstance( nodesIn[i].param, nn.Parameter) else BoundBuffers nodesIn[i] = nodesIn[i]._replace(bound_node=bound_class( - ori_name=nodesIn[i].ori_name, - value=nodesIn[i].param, - perturbation=nodesIn[i].perturbation)) + ori_name=nodesIn[i].ori_name, value=nodesIn[i].param, + perturbation=nodesIn[i].perturbation, options=self.bound_opts)) unsupported_ops = [] @@ -695,7 +640,6 @@ def _convert_nodes(self, model, global_input): logger.error('The node has an unsupported operation: %s', nodesOP[n]) continue - attr['device'] = self.device # FIXME generalize @@ -735,41 +679,14 @@ def _build_graph(self, nodesOP, nodesIn, nodesOut, template): # output element. In this case, we are assuming that we aim to compute # the bounds for the first output element by default. self.final_name = nodesOut[0].name - self.input_name, self.input_index, self.root_name = [], [], [] + self.input_name, self.input_index, self.root_names = [], [], [] self.output_name = [n.name for n in nodesOut] self.output_template = template for node in nodesIn: self.add_input_node(node, index=node.input_index) self.add_nodes(nodesOP) if self.conv_mode == 'patches': - self.root_name = [node.name for node in nodesIn] - - # Make sure the nodes already have `name` and `input_name` - def add_nodes(self, nodes): - nodes = [(node if isinstance(node, Bound) else node.bound_node) - for node in nodes] - for node in nodes: - self._modules[node.name] = node - node.output_name = [] - if len(node.inputs) == 0: - self.root_name.append(node.name) - for node in nodes: - for l_pre in node.inputs: - l_pre.output_name.append(node.name) - for node in nodes: - if isinstance(node, BoundOptimizableActivation): - self.optimizable_activations.append(node) - if isinstance(node, BoundRelu): - self.relus.append(node) - - def add_input_node(self, node, index=None): - self.add_nodes([node]) - self.input_name.append(node.name) - # default value for input_index - if index == 'auto': - index = max([0] + [(i + 1) - for i in self.input_index if i is not None]) - self.input_index.append(index) + self.root_names: List[str] = [node.name for node in nodesIn] def rename_nodes(self, nodesOP, nodesIn, rename_dict): def rename(node): @@ -827,8 +744,8 @@ def _split_complex(self, nodesOP, nodesIn): def _get_node_name_map(self): """Build a dict with {ori_name: name, name: ori_name}""" self.node_name_map = {} - for node in self._modules.values(): - if isinstance(node, BoundInput) or isinstance(node, BoundParams): + for node in self.nodes(): + if isinstance(node, (BoundInput, BoundParams)): for p in list(node.named_parameters()): if node.ori_name not in self.node_name_map: name = f'{node.name}.{p[0]}' @@ -881,13 +798,17 @@ def check_prior_bounds(self, node): return for n in node.inputs: self.check_prior_bounds(n) - for i in getattr(node, 'requires_input_bounds', []): - self.compute_intermediate_bounds( - node.inputs[i], prior_checked=True) + for i in range(len(node.inputs)): + if (i in node.requires_input_bounds or not node.inputs[i].perturbed + or node.inputs[i].name in self.layers_with_constraint): + self.compute_intermediate_bounds( + node.inputs[i], prior_checked=True) node.prior_checked = True def compute_intermediate_bounds(self, node, prior_checked=False): if getattr(node, 'lower', None) is not None: + if node.name in self.layers_with_constraint: + node.clamp_interim_bounds() return logger.debug(f'Getting the bounds of {node}') @@ -922,17 +843,6 @@ def compute_intermediate_bounds(self, node, prior_checked=False): # computed from their input nodes by IBP # (such as BoundRelu, BoundNeg) logger.debug('IBP propagation for intermediate bounds on %s', node) - elif (isinstance(node, BoundReshape) - and hasattr(node.inputs[0], 'lower') - and hasattr(node.inputs[1], 'value')): - # TODO merge this with `check_IBP_intermediate` - # Node for input value. - val_input = node.inputs[0] - # Node for input parameter (e.g., shape, permute) - arg_input = node.inputs[1] - node.lower = node.forward(val_input.lower, arg_input.value) - node.upper = node.forward(val_input.upper, arg_input.value) - node.interval = (node.lower, node.upper) else: # For the first linear layer, IBP can give the same tightness # as CROWN. @@ -973,16 +883,33 @@ def compute_intermediate_bounds(self, node, prior_checked=False): newC, reduced_dim, unstable_idx, unstable_size = sparse_C if unstable_idx is None or unstable_size > 0: - if self.return_A: - node.lower, node.upper, _ = self.backward_general( - C=newC, node=node, unstable_idx=unstable_idx, - unstable_size=unstable_size) - else: - # Compute backward bounds only when there are unstable - # neurons, or when we don't know which neurons are unstable. - node.lower, node.upper = self.backward_general( - C=newC, node=node, unstable_idx=unstable_idx, - unstable_size=unstable_size) + apply_output_constraints_to = ( + self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] + ) + # Special case for BoundRelu when sparse intermediate bounds are disabled + # Currently sparse intermediate bounds are restricted to ReLU models only + skip = False + if unstable_idx is None: + if (len(node.output_name) == 1 + and isinstance(self[node.output_name[0]], + (BoundRelu, BoundSignMerge)) + and node.name in self.reference_bounds): + lower, upper = self.reference_bounds[node.name] + fully_stable = torch.logical_or(lower>=0, upper<=0).all() + if fully_stable: + node.lower, node.upper = lower, upper + skip = True + if not skip: + if self.return_A: + node.lower, node.upper, _ = self.backward_general( + node, newC, unstable_idx=unstable_idx, + apply_output_constraints_to=apply_output_constraints_to) + else: + # Compute backward bounds only when there are unstable + # neurons, or when we don't know which neurons are unstable. + node.lower, node.upper = self.backward_general( + node, newC, unstable_idx=unstable_idx, + apply_output_constraints_to=apply_output_constraints_to) if reduced_dim: self.restore_sparse_bounds( @@ -997,32 +924,18 @@ def compute_intermediate_bounds(self, node, prior_checked=False): # Initially, the reference bound and the computed bound can be # exactly the same when intermediate layer beta is 0. This will # prevent gradients flow. So we need a small guard here. - if self.intermediate_constr is not None: - # Intermediate layer beta is used. - # Note that we cannot just take the reference bounds if - # they are better - this makes alphas have zero gradients. - new_lower = 0.9 * ref_bounds[0] + 0.1 * node.lower - new_upper = 0.9 * ref_bounds[1] + 0.1 * node.upper - node.lower = torch.max(new_lower, node.lower) - node.upper = torch.min(new_upper, node.upper) - # Additionally, if the reference bounds say a neuron is - # stable, we always keep it. (FIXME: this is for ReLU only). - lower_stable = ref_bounds[0] >= 0. - node.lower[lower_stable] = ref_bounds[0][lower_stable] - upper_stable = ref_bounds[1] <= 0. - node.upper[upper_stable] = ref_bounds[1][upper_stable] - else: - # Set the intermediate layer bounds using reference bounds, - # always choosing the tighter one. - node.lower = ( - torch.max(ref_bounds[0], node.lower).detach() - - node.lower.detach() + node.lower) - node.upper = ( - node.upper - (node.upper.detach() - - torch.min(ref_bounds[1], node.upper).detach())) + # Set the intermediate layer bounds using reference bounds, + # always choosing the tighter one. + node.lower = (torch.max(ref_bounds[0], node.lower).detach() + - node.lower.detach() + node.lower) + node.upper = (node.upper - (node.upper.detach() + - torch.min(ref_bounds[1], node.upper).detach())) # Otherwise, we only use reference bounds to check which neurons # are unstable. + # prior constraint bounds + if node.name in self.layers_with_constraint: + node.clamp_interim_bounds() # FIXME (12/28): we should be consistent, and only use # node.interval, do not use node.lower or node.upper! node.interval = (node.lower, node.upper) @@ -1047,7 +960,7 @@ def compute_bounds( forward=False, bound_lower=True, bound_upper=True, reuse_ibp=False, reuse_alpha=False, return_A=False, needed_A_dict=None, final_node_name=None, average_A=False, - intermediate_layer_bounds=None, reference_bounds=None, + interm_bounds=None, reference_bounds=None, intermediate_constr=None, alpha_idx=None, aux_reference_bounds=None, need_A_only=False, cutter=None, decision_thresh=None, @@ -1126,7 +1039,7 @@ def compute_bounds( use this decision_thresh to dynamically optimize those domains that <= the threshold. - intermediate_layer_bounds: A dictionary of 2-element tuple/list + interm_bounds: A dictionary of 2-element tuple/list containing lower and upper bounds for intermediate layers. The dictionary keys should include the names of the layers whose bounds should be set without recomputation. The layer names can be @@ -1137,7 +1050,7 @@ def compute_bounds( you only need to set intermediate layer bounds for certain layers, then just include these layers' names in the dictionary. - reference_bounds: Format is similar to "intermediate_layer_bounds". + reference_bounds: Format is similar to "interm_bounds". However, these bounds are only used as a reference, and the bounds for intermediate layers will still be computed (e.g., using CROWN, IBP or other specified methods). The computed bounds will be @@ -1155,6 +1068,9 @@ def compute_bounds( is `True`, return a tuple of lower bound, upper bound, and `A` dictionary. """ + # This method only prepares everything by setting all required parameters. + # The main logic is located in `_compute_bounds_main`. It may be called + # repeatedly for CROWN optimizations. logger.debug(f'Compute bounds with {method}') if needed_A_dict is None: needed_A_dict = {} @@ -1163,60 +1079,26 @@ def compute_bounds( 'At least one of bound_lower and bound_upper must be True') # Several shortcuts. + compute_optimized = False method = method.lower() if method is not None else method if method == 'ibp': # Pure IBP bounds. - method = None - IBP = True + method, IBP = None, True elif method in ['ibp+backward', 'ibp+crown', 'crown-ibp']: - method = 'backward' - IBP = True + method, IBP = 'backward', True elif method == 'crown': method = 'backward' elif method == 'forward': forward = True elif method == 'forward+backward' or method == 'forward+crown': - method = 'backward' - forward = True + method, forward = 'backward', True elif method in ['crown-optimized', 'alpha-crown', 'forward-optimized']: # Lower and upper bounds need two separate rounds of optimization. if method == 'forward-optimized': method = 'forward' else: method = 'backward' - if bound_lower: - ret1 = self.get_optimized_bounds( - x=x, C=C, method=method, - intermediate_layer_bounds=intermediate_layer_bounds, - reference_bounds=reference_bounds, bound_lower=bound_lower, - bound_upper=False, return_A=return_A, - aux_reference_bounds=aux_reference_bounds, - needed_A_dict=needed_A_dict, - final_node_name=final_node_name, - cutter=cutter, decision_thresh=decision_thresh) - if bound_upper: - ret2 = self.get_optimized_bounds( - x=x, C=C, method=method, - intermediate_layer_bounds=intermediate_layer_bounds, - reference_bounds=reference_bounds, bound_lower=False, - bound_upper=bound_upper, return_A=return_A, - aux_reference_bounds=aux_reference_bounds, - needed_A_dict=needed_A_dict, - final_node_name=final_node_name, - cutter=cutter, decision_thresh=decision_thresh) - if bound_lower and bound_upper: - if return_A: - # Needs to merge the A dictionary. - lA_dict = ret1[2] - uA_dict = ret2[2] - merged_A = self.merge_A_dict(lA_dict, uA_dict) - return ret1[0], ret2[1], merged_A - else: - return ret1[0], ret2[1] - elif bound_lower: - return ret1 # ret1[1] is None. - elif bound_upper: - return ret2 # ret2[0] is None. + compute_optimized = True if reference_bounds is None: reference_bounds = {} @@ -1234,52 +1116,167 @@ def compute_bounds( A_dict = {} if return_A else None if x is not None: - self._set_input( - *x, intermediate_layer_bounds=intermediate_layer_bounds) + if isinstance(x, torch.Tensor): + x = (x,) + self.set_input(*x, interm_bounds=interm_bounds) - if IBP and method is None and reuse_ibp: - # directly return the previously saved ibp bounds - return self.ibp_lower, self.ibp_upper - root = [self[name] for name in self.root_name] - batch_size = root[0].value.shape[0] + roots = self.roots() + batch_size = roots[0].value.shape[0] dim_in = 0 - for i in range(len(root)): - value = root[i].forward() - if getattr(root[i], 'perturbation', None) is not None: - ret_init = root[i].perturbation.init( + for i in range(len(roots)): + value = roots[i].forward() + if getattr(roots[i], 'perturbation', None) is not None: + ret_init = roots[i].perturbation.init( value, aux=aux, forward=forward) - root[i].linear, root[i].center, root[i].aux = ret_init + roots[i].linear, roots[i].center, roots[i].aux = ret_init # This input/parameter has perturbation. # Create an interval object. - root[i].interval = Interval( - root[i].linear.lower, root[i].linear.upper, - ptb=root[i].perturbation) + roots[i].interval = Interval( + roots[i].linear.lower, roots[i].linear.upper, + ptb=roots[i].perturbation) if forward: - root[i].dim = root[i].linear.lw.shape[1] - dim_in += root[i].dim + roots[i].dim = roots[i].linear.lw.shape[1] + dim_in += roots[i].dim else: # This input/parameter does not has perturbation. # Use plain tuple defaulting to Linf perturbation. - root[i].interval = (value, value) - root[i].forward_value = root[i].value = value - root[i].center = root[i].lower = root[i].upper = value + roots[i].interval = (value, value) + roots[i].forward_value = roots[i].value = value + roots[i].center = roots[i].lower = roots[i].upper = value - root[i].lower, root[i].upper = root[i].interval + roots[i].lower, roots[i].upper = roots[i].interval if forward: - self.init_forward(root, dim_in) + self.init_forward(roots, dim_in) + + for n in self.nodes(): + if isinstance(n, BoundRelu): + for node in n.inputs: + if isinstance(node, BoundConv): + # whether this Conv is followed by a ReLU + node.relu_followed = True + + # Inject update mask inside the activations + # update_mask: None or bool tensor([batch_size]) + # If set to a tensor, only update the alpha and beta of selected + # element (with element=1). + n.alpha_beta_update_mask = update_mask + + final = (self.final_node() if final_node_name is None + else self[final_node_name]) + # BFS to find out whether each node is used given the current final node + self._set_used_nodes(final) - final = self.final_node( - ) if final_node_name is None else self[final_node_name] + # FIXME clean + self.use_forward = forward + self.batch_size = batch_size + self.dim_in = dim_in + self.return_A = return_A + self.A_dict = A_dict + self.needed_A_dict = needed_A_dict + self.intermediate_constr = intermediate_constr + self.reference_bounds = reference_bounds + self.aux_reference_bounds = aux_reference_bounds + self.final_node_name = final.name + + if compute_optimized: + kwargs = dict(x=x, C=C, method=method, interm_bounds=interm_bounds, + reference_bounds=reference_bounds, return_A=return_A, + aux_reference_bounds=aux_reference_bounds, + needed_A_dict=needed_A_dict, + final_node_name=final_node_name, + cutter=cutter, decision_thresh=decision_thresh) + if bound_upper: + ret2 = self._get_optimized_bounds(bound_side='upper', **kwargs) + if bound_lower: + ret1 = self._get_optimized_bounds(bound_side='lower', **kwargs) + if bound_lower and bound_upper: + if return_A: + # Needs to merge the A dictionary. + return ret1[0], ret2[1], self.merge_A_dict(ret1[2], ret2[2]) + else: + return ret1[0], ret2[1] + elif bound_lower: + return ret1 # ret1[1] is None. + elif bound_upper: + return ret2 # ret2[0] is None. + + + return self._compute_bounds_main(C=C, + method=method, + IBP=IBP, + bound_lower=bound_lower, + bound_upper=bound_upper, + reuse_ibp=reuse_ibp, + reuse_alpha=reuse_alpha, + average_A=average_A, + alpha_idx=alpha_idx, + need_A_only=need_A_only, + update_mask=update_mask) + + def save_intermediate(self, save_path=None): + r"""A function for saving intermediate bounds. + + Please call this function after `compute_bounds`, or it will output + IBP bounds by default. + + Args: + save_path (str, default `None`): If `None`, the intermediate bounds + will not be saved, or it will be saved at the designated path. + + Returns: + save_dict (dict): Return a dictionary of lower and upper bounds, with + the key being the name of the layer. + """ + save_dict = OrderedDict() + for node in self.nodes(): + if not hasattr(node, 'interval'): + ibp_lower, ibp_upper = self.IBP_general(node, + delete_bounds_after_use=True) + dim_output = int(prod(node.output_shape[1:])) + C = torch.eye(dim_output, device=self.device).expand( + self.batch_size, dim_output, dim_output) + crown_lower, crown_upper = self.backward_general(node, C=C) + save_dict[node.name] = ( + torch.max(crown_lower, ibp_lower), + torch.min(crown_upper, ibp_upper)) + else: + save_dict[node.name] = (node.lower, node.upper) + + if save_path is not None: + torch.save(save_dict, save_path) + return save_dict + + def _compute_bounds_main(self, C=None, method='backward', IBP=False, + bound_lower=True, bound_upper=True, reuse_ibp=False, + reuse_alpha=False, average_A=False, alpha_idx=None, + need_A_only=False, update_mask=None): + """The core implementation of compute_bounds. + + Seperated because compute_bounds may call _get_optimized_bounds which + repeatedly calls this method. Otherwise, the preprocessing done in + compute_bounds would be executed for each iteration. + """ + + final = (self.final_node() if self.final_node_name is None + else self[self.final_node_name]) logger.debug(f'Final node {final.__class__.__name__}({final.name})') + if IBP and method is None and reuse_ibp: + # directly return the previously saved ibp bounds + return self.ibp_lower, self.ibp_upper + if IBP: self.ibp_lower, self.ibp_upper = self.IBP_general(node=final, C=C) if method is None: return self.ibp_lower, self.ibp_upper + # TODO: if compute_bounds is called with a method that causes alphas to be + # optimized, C will be allocated in each iteration. We could allocate it once + # in compute_bounds, but e.g. `IBP_general` and code in `_get_optimized_bounds` + # relies on the fact that it can be None if C is None: # C is an identity matrix by default if final.output_shape is None: @@ -1288,64 +1285,35 @@ def compute_bounds( dim_output = int(prod(final.output_shape[1:])) # TODO: use an eyeC object here. C = torch.eye(dim_output, device=self.device).expand( - batch_size, dim_output, dim_output) + self.batch_size, dim_output, dim_output) # Reuse previously saved alpha values, # even if they are not optimized now + # This must be done here instead of `compute_bounds`, as other code might change + # it (e.g. `_get_optimized_bounds`) if reuse_alpha: - for node in self.optimizable_activations: - node.opt_reuse() - else: - for node in self.optimizable_activations: - node.opt_no_reuse() - - # Inject update mask inside the activations - # update_mask: None or bool tensor([batch_size]) - # If set to a tensor, only update the alpha and beta of selected - # element (with element=1). - - if update_mask is None: - for node in self.optimizable_activations: - node.clean_alpha_beta_update_mask() + self.opt_reuse() else: - for node in self.optimizable_activations: - node.set_alpha_beta_update_mask(update_mask) - - for n in self._modules.values(): - # Check whether all prior intermediate bounds already exist - n.prior_checked = False - if isinstance(i, BoundRelu): - for node in i.inputs: - if isinstance(node, BoundConv): - # whether this Conv is followed by a ReLU - node.relu_followed = True + self.opt_no_reuse() - # BFS to find out whether each node is used given the current final node - self._set_used_nodes(final) - - # FIXME clean - self.use_forward = forward - self.root = root - self.batch_size = batch_size - self.dim_in = dim_in - self.return_A = return_A - self.A_dict = A_dict - self.needed_A_dict = needed_A_dict - self.intermediate_constr = intermediate_constr - self.reference_bounds = reference_bounds - self.aux_reference_bounds = aux_reference_bounds - self.final_node_name = final.name + for node in self.nodes(): + # All nodes may need to be recomputed + node.prior_checked = False self.check_prior_bounds(final) if method == 'backward': + apply_output_constraints_to = ( + self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] + ) # This is for the final output bound. # No need to pass in intermediate layer beta constraints. ret = self.backward_general( - C=C, node=final, + final, C, bound_lower=bound_lower, bound_upper=bound_upper, average_A=average_A, need_A_only=need_A_only, - unstable_idx=alpha_idx, update_mask=update_mask) + unstable_idx=alpha_idx, update_mask=update_mask, + apply_output_constraints_to=apply_output_constraints_to) # FIXME when C is specified, lower and upper should not be saved to # final.lower and final.upper, because they are not the bounds for # the node. @@ -1360,7 +1328,7 @@ def _set_used_nodes(self, final): if final.name != self.last_final_node: self.last_final_node = final.name self.used_nodes = [] - for i in self._modules.values(): + for i in self.nodes(): i.used = False final.used = True queue = deque([final]) @@ -1373,32 +1341,29 @@ def _set_used_nodes(self, final): queue.append(n_pre) # Based on "used" and "perturbed" properties, find out which # layer requires intermediate layer bounds. - self.layers_requiring_bounds = ( - self.get_layers_requiring_bounds()) - - def add_intermediate_perturbation(self, node, perturbation): - """Add perturbation to an intermediate node and it is treated as an - independent node in bound computation.""" - node.perturbation = perturbation - node.perturbed = True - # NOTE This change is currently inreversible - if not node.name in self.root_name: - self.root_name.append(node.name) + self.layers_requiring_bounds = self.get_layers_requiring_bounds() from .interval_bound import ( IBP_general, _IBP_loss_fusion, check_IBP_intermediate, check_IBP_first_linear) from .forward_bound import ( - forward_general, forward_general_dynamic, init_forward) + forward_general, forward_general_dynamic, forward_refinement, init_forward) from .backward_bound import ( - backward_general, get_sparse_C, check_optimized_variable_sparsity, - restore_sparse_bounds, get_alpha_crown_start_nodes, - get_unstable_locations, batched_backward) - from .optimized_bounds import get_optimized_bounds, init_slope - from .beta_crown import ( - beta_bias, save_best_intermediate_betas, - print_optimized_beta) - from .jacobian import augment_gradient_graph, compute_jacobian_bounds - - - from .solver_module import build_solver_module, _build_solver_input, _build_solver_general + backward_general, get_sparse_C, concretize, + check_optimized_variable_sparsity, restore_sparse_bounds, + get_alpha_crown_start_nodes, get_unstable_locations, batched_backward, + _preprocess_C) + from .output_constraints import backward_general_with_output_constraint + from .optimized_bounds import ( + _get_optimized_bounds, init_alpha, update_best_beta, + opt_reuse, opt_no_reuse, _to_float64, _to_default_dtype) + from .beta_crown import (beta_crown_backward_bound, reset_beta, set_beta, + set_beta_cuts, get_split_nodes) + from .jacobian import (augment_gradient_graph, compute_jacobian_bounds, + _expand_jacobian) + from .optimize_graph import _optimize_graph + from .edit_graph import add_nodes, add_input_node, delete_node, replace_node + + + from .solver_module import ( + build_solver_module, _build_solver_input, _build_solver_general, _reset_solver_vars) diff --git a/auto_LiRPA/bound_multi_gpu.py b/auto_LiRPA/bound_multi_gpu.py index 057ef42..cd77386 100644 --- a/auto_LiRPA/bound_multi_gpu.py +++ b/auto_LiRPA/bound_multi_gpu.py @@ -110,14 +110,14 @@ def get_property(model, node_class=None, att_name=None, node_name=None): # `BoundedModule` type rather than bound nodes. for node in model._modules.values(): if node.name == node_name: - return getattr(node, att_name) + return getattr(node, att_name) else: # Find node by class for _, node in model.named_modules(): # Find the Exp neuron in computational graph if isinstance(node, node_class): return getattr(node, att_name) - + def state_dict(self, destination=None, prefix='', keep_vars=False): # add 'module.' here before each keys in self.module.state_dict() if needed return self.module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) diff --git a/auto_LiRPA/bound_op_map.py b/auto_LiRPA/bound_op_map.py index 7c490fa..c2f8a44 100644 --- a/auto_LiRPA/bound_op_map.py +++ b/auto_LiRPA/bound_op_map.py @@ -1,5 +1,6 @@ -from .bound_ops import Bound, BoundLinear, BoundPrimConstant -from .bound_ops import BoundReluGrad, BoundConv2dGrad, BoundSqr +from .bound_ops import ( + Bound, BoundLinear, BoundPrimConstant, BoundGELU, BoundReluGrad, + BoundConv2dGrad, BoundSqr, BoundJacobianOP) bound_op_map = { 'onnx::Gemm': BoundLinear, @@ -7,6 +8,8 @@ 'grad::Relu': BoundReluGrad, 'grad::Conv2d': BoundConv2dGrad, 'grad::Sqr': BoundSqr, + 'grad::jacobian': BoundJacobianOP, + 'custom::Gelu': BoundGELU, } def register_custom_op(op_name: str, bound_obj: Bound) -> None: diff --git a/auto_LiRPA/bounded_tensor.py b/auto_LiRPA/bounded_tensor.py index 8b44413..40f8a48 100644 --- a/auto_LiRPA/bounded_tensor.py +++ b/auto_LiRPA/bounded_tensor.py @@ -1,10 +1,9 @@ import copy -import torch import torch.nn as nn -from torch import Tensor as Tensor - +from torch import Tensor import torch._C as _C + class BoundedTensor(Tensor): @staticmethod # We need to override the __new__ method since Tensor is a C class diff --git a/auto_LiRPA/edit_graph.py b/auto_LiRPA/edit_graph.py new file mode 100644 index 0000000..158f38b --- /dev/null +++ b/auto_LiRPA/edit_graph.py @@ -0,0 +1,62 @@ +"""Edit the computational graph in BoundedModule.""" + +from auto_LiRPA.bound_ops import Bound + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule + + +# Make sure the nodes already have `name` and `input_name` +def add_nodes(self: 'BoundedModule', nodes): + # TODO check duplicate names + nodes = [(node if isinstance(node, Bound) else node.bound_node) + for node in nodes] + for node in nodes: + self._modules[node.name] = node + node.output_name = [] + if len(node.inputs) == 0: + self.root_names.append(node.name) + for node in nodes: + for l_pre in node.inputs: + l_pre.output_name.append(node.name) + if (getattr(node, 'has_constraint', False) and + node.name not in self.layers_with_constraint): + self.layers_with_constraint.append(node.name) + + +def add_input_node(self: 'BoundedModule', node, index=None): + self.add_nodes([node]) + self.input_name.append(node.name) + # default value for input_index + if index == 'auto': + index = max([0] + [(i + 1) + for i in self.input_index if i is not None]) + self.input_index.append(index) + + +def delete_node(self: 'BoundedModule', node): + for node_inp in node.inputs: + node_inp.output_name.pop(node_inp.output_name.index(node.name)) + self._modules.pop(node.name) + # TODO Create a list to contain all such lists such as + # "relus" and "optimizable_activations" + self.relus = [ + item for item in self.relus if item != node] + self.optimizable_activations = [ + item for item in self.optimizable_activations if item != node] + + +def replace_node(self: 'BoundedModule', node_old, node_new): + assert node_old != node_new + for node in self.nodes(): + for i in range(len(node.inputs)): + if node.inputs[i] == node_old: + node.inputs[i] = node_new + node_new.output_name += node_old.output_name + if self.final_name == node_old.name: + self.final_name = node_new.name + for i in range(len(self.output_name)): + if self.output_name[i] == node_old.name: + self.output_name[i] = node_new.name + self.delete_node(node_old) diff --git a/auto_LiRPA/eps_scheduler.py b/auto_LiRPA/eps_scheduler.py index 9332855..4320483 100644 --- a/auto_LiRPA/eps_scheduler.py +++ b/auto_LiRPA/eps_scheduler.py @@ -47,10 +47,10 @@ def update_loss(self, new_loss): def train(self): self.is_training = True - + def eval(self): self.is_training = False - + # Set how many batches in an epoch def set_epoch_length(self, epoch_length): self.epoch_length = epoch_length @@ -133,8 +133,8 @@ def __init__(self, max_eps, opt_str): def __repr__(self): return ''.format( - self.epoch_start_eps, self.epoch_end_eps) - + self.epoch_start_eps, self.epoch_end_eps) + def step_epoch(self, verbose = True): self.epoch += 1 self.batch = 0 @@ -148,7 +148,7 @@ def step_epoch(self, verbose = True): self.epoch_start_eps = min(eps_epoch * eps_epoch_step, self.max_eps) self.epoch_end_eps = min((eps_epoch + 1) * eps_epoch_step, self.max_eps) else: - self.epoch_start_eps = max(0, + self.epoch_start_eps = max(0, self.max_eps - ((eps_epoch - self.schedule_length_half) * eps_epoch_step)) self.epoch_end_eps = max(0, self.epoch_start_eps - eps_epoch_step) self.eps = self.epoch_start_eps @@ -172,7 +172,7 @@ def __init__(self, max_eps, opt_str): assert self.mid_point >= 0. and self.mid_point <= 1. self.batch = 0 - + # Set how many batches in an epoch def set_epoch_length(self, epoch_length): if self.epoch_length != self.epoch_length: @@ -183,11 +183,11 @@ def set_epoch_length(self, epoch_length): def step_epoch(self, verbose = True): super(SmoothedScheduler, self).step_epoch() - # FIXME + # FIXME if verbose == False: for i in range(self.epoch_length): self.step_batch() - + # Smooth schedule that slowly morphs into a linear schedule. # Code is based on DeepMind's IBP implementation: # https://github.com/deepmind/interval-bound-propagation/blob/2c1a56cb0497d6f34514044877a8507c22c1bd85/interval_bound_propagation/src/utils.py#L84 @@ -230,7 +230,7 @@ def __init__(self, max_eps, opt_str): self.small_loss_thresh = float(self.params.get('small_loss_thresh', 0.05)) self.epoch = 0 self.eps_step = self.min_eps_step - + def step_batch(self): if self.eps < self.max_eps and self.epoch >= self.schedule_start and self.is_training: if self.loss != self.loss or self.prev_loss != self.prev_loss: @@ -275,4 +275,3 @@ def step_batch(self): plt.grid() plt.tight_layout() plt.savefig('epsilon.pdf') - diff --git a/auto_LiRPA/forward_bound.py b/auto_LiRPA/forward_bound.py index f21066e..58cc978 100644 --- a/auto_LiRPA/forward_bound.py +++ b/auto_LiRPA/forward_bound.py @@ -5,10 +5,16 @@ from .linear_bound import LinearBound from .perturbations import PerturbationLpNorm +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule + import sys sys.setrecursionlimit(1000000) -def forward_general(self, C=None, node=None, concretize=False, offset=0): + +def forward_general(self: 'BoundedModule', C=None, node=None, concretize=False, + offset=0): if self.bound_opts['dynamic_forward']: return self.forward_general_dynamic(C, node, concretize, offset) @@ -56,20 +62,21 @@ def forward_general(self, C=None, node=None, concretize=False, offset=0): if concretize: if lw is not None or uw is not None: + roots = self.roots() prev_dim_in = 0 batch_size = lw.shape[0] assert (lw.ndim > 1) lA = lw.reshape(batch_size, self.dim_in, -1).transpose(1, 2) uA = uw.reshape(batch_size, self.dim_in, -1).transpose(1, 2) - for i in range(len(self.root)): - if hasattr(self.root[i], 'perturbation') and self.root[i].perturbation is not None: - _lA = lA[:, :, prev_dim_in : (prev_dim_in + self.root[i].dim)] - _uA = uA[:, :, prev_dim_in : (prev_dim_in + self.root[i].dim)] - lower = lower + self.root[i].perturbation.concretize( - self.root[i].center, _lA, sign=-1, aux=self.root[i].aux).view(lower.shape) - upper = upper + self.root[i].perturbation.concretize( - self.root[i].center, _uA, sign=+1, aux=self.root[i].aux).view(upper.shape) - prev_dim_in += self.root[i].dim + for i in range(len(roots)): + if hasattr(roots[i], 'perturbation') and roots[i].perturbation is not None: + _lA = lA[:, :, prev_dim_in : (prev_dim_in + roots[i].dim)] + _uA = uA[:, :, prev_dim_in : (prev_dim_in + roots[i].dim)] + lower = lower + roots[i].perturbation.concretize( + roots[i].center, _lA, sign=-1, aux=roots[i].aux).view(lower.shape) + upper = upper + roots[i].perturbation.concretize( + roots[i].center, _uA, sign=+1, aux=roots[i].aux).view(upper.shape) + prev_dim_in += roots[i].dim linear.lower, linear.upper = lower, upper if C is None: @@ -85,12 +92,12 @@ def forward_general(self, C=None, node=None, concretize=False, offset=0): need_refinement = True break if need_refinement: - forward_refinement(self, node) + self.forward_refinement(node) return lower, upper -def forward_general_dynamic( - self, C=None, node=None, concretize=False, offset=0): +def forward_general_dynamic(self: 'BoundedModule', C=None, node=None, + concretize=False, offset=0): max_dim = self.bound_opts['forward_max_dim'] if C is None: @@ -132,7 +139,6 @@ def forward_general_dynamic( if not node.perturbed: if not hasattr(node, 'lower'): node.lower = node.upper = self.get_forward_value(node) - raise NotImplementedError if concretize: return node.lower, node.upper else: @@ -212,10 +218,9 @@ def forward_general_dynamic( return linear -def clean_memory(self, node): +def clean_memory(self: 'BoundedModule', node): """ Remove linear bounds that are no longer needed. """ # TODO add an option to retain these bounds - for inp in node.inputs: if hasattr(inp, 'linear') and inp.linear is not None: clean = True @@ -230,7 +235,7 @@ def clean_memory(self, node): delattr(inp, 'linear') -def forward_refinement(self, node): +def forward_refinement(self: 'BoundedModule', node): """ Refine forward bounds with backward bound propagation (only refine unstable positions). """ unstable_size_before = torch.logical_and(node.lower < 0, node.upper > 0).sum() @@ -250,51 +255,51 @@ def forward_refinement(self, node): # TODO also update linear bounds? -def init_forward(self, root, dim_in): +def init_forward(self: 'BoundedModule', roots, dim_in): if dim_in == 0: raise ValueError("At least one node should have a specified perturbation") prev_dim_in = 0 - # Assumption: root[0] is the input node which implies batch_size - batch_size = root[0].value.shape[0] + # Assumption: roots[0] is the input node which implies batch_size + batch_size = roots[0].value.shape[0] dynamic = self.bound_opts['dynamic_forward'] - for i in range(len(root)): - if hasattr(root[i], 'perturbation') and root[i].perturbation is not None: - shape = root[i].linear.lw.shape + for i in range(len(roots)): + if hasattr(roots[i], 'perturbation') and roots[i].perturbation is not None: + shape = roots[i].linear.lw.shape if dynamic: if shape[1] != dim_in: raise NotImplementedError('Dynamic forward bound is not supported yet when there are multiple perturbed inputs.') - ptb = root[i].perturbation + ptb = roots[i].perturbation if (type(ptb) != PerturbationLpNorm or ptb.norm < np.inf or ptb.x_L is None or ptb.x_U is None): raise NotImplementedError( 'For dynamic forward bounds, only Linf (box) perturbations are supported, and x_L and x_U must be explicitly provided.') - root[i].linear.x_L = ( + roots[i].linear.x_L = ( ptb.x_L_sparse.view(batch_size, -1) if ptb.sparse else ptb.x_L.view(batch_size, -1)) - root[i].linear.x_U = ( + roots[i].linear.x_U = ( ptb.x_U_sparse.view(batch_size, -1) if ptb.sparse else ptb.x_U.view(batch_size, -1)) else: - lw = torch.zeros(shape[0], dim_in, *shape[2:]).to(root[i].linear.lw) - lw[:, prev_dim_in:(prev_dim_in+shape[1])] = root[i].linear.lw - if root[i].linear.lw.data_ptr() == root[i].linear.uw.data_ptr(): + lw = torch.zeros(shape[0], dim_in, *shape[2:]).to(roots[i].linear.lw) + lw[:, prev_dim_in:(prev_dim_in+shape[1])] = roots[i].linear.lw + if roots[i].linear.lw.data_ptr() == roots[i].linear.uw.data_ptr(): uw = lw else: - uw = torch.zeros(shape[0], dim_in, *shape[2:]).to(root[i].linear.uw) - uw[:, prev_dim_in:(prev_dim_in+shape[1])] = root[i].linear.uw - root[i].linear.lw = lw - root[i].linear.uw = uw + uw = torch.zeros(shape[0], dim_in, *shape[2:]).to(roots[i].linear.uw) + uw[:, prev_dim_in:(prev_dim_in+shape[1])] = roots[i].linear.uw + roots[i].linear.lw = lw + roots[i].linear.uw = uw if i >= self.num_global_inputs: - root[i].forward_value = root[i].forward_value.unsqueeze(0).repeat( + roots[i].forward_value = roots[i].forward_value.unsqueeze(0).repeat( *([batch_size] + [1] * self.forward_value.ndim)) prev_dim_in += shape[1] else: - b = fv = root[i].forward_value + b = fv = roots[i].forward_value shape = fv.shape - if root[i].from_input: + if roots[i].from_input: w = torch.zeros(shape[0], dim_in, *shape[1:], device=self.device) warnings.warn(f'Creating a LinearBound with zero weights with shape {w.shape}') else: w = None - root[i].linear = LinearBound(w, b, w, b, b, b) - root[i].lower = root[i].upper = b + roots[i].linear = LinearBound(w, b, w, b, b, b) + roots[i].lower = roots[i].upper = b diff --git a/auto_LiRPA/interval_bound.py b/auto_LiRPA/interval_bound.py index a4ebcf5..2033aba 100644 --- a/auto_LiRPA/interval_bound.py +++ b/auto_LiRPA/interval_bound.py @@ -1,8 +1,16 @@ import torch from .bound_ops import * +from .utils import logger +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule -def IBP_general(self, node=None, C=None, delete_bounds_after_use=False): + +def IBP_general(self: 'BoundedModule', node=None, C=None, + delete_bounds_after_use=False): + + logger.debug('IBP for %s', node) def _delete_unused_bounds(node_list): """Delete bounds from input layers after use to save memory. Used when @@ -18,9 +26,9 @@ def _delete_unused_bounds(node_list): if res is not None: return res - if not node.perturbed and hasattr(node, 'forward_value'): - node.lower, node.upper = node.interval = ( - node.forward_value, node.forward_value) + if not node.perturbed: + fv = self.get_forward_value(node) + node.lower, node.upper = node.interval = (fv, fv) to_be_deleted_bounds = [] if not hasattr(node, 'interval'): @@ -56,12 +64,11 @@ def _delete_unused_bounds(node_list): _delete_unused_bounds(to_be_deleted_bounds) return node.interval -def _IBP_loss_fusion(self, node, C): +def _IBP_loss_fusion(self: 'BoundedModule', node, C): """Merge BoundLinear, BoundGatherElements and BoundSub. Improvement when loss fusion is used in training. """ - # not using loss fusion if not self.bound_opts.get('loss_fusion', False): return None @@ -105,29 +112,40 @@ def _IBP_loss_fusion(self, node, C): return None -def check_IBP_intermediate(self, node): +def check_IBP_intermediate(self: 'BoundedModule', node): """ Check if we use IBP bounds to compute intermediate bounds on this node. - We check if we can get bounds by only visiting operators in - `self.ibp_intermediate`. Currently, assume all eligible operators have - exactly one input. + Currently, assume all eligible operators have exactly one input. """ + if (isinstance(node, BoundReshape) + and hasattr(node.inputs[0], 'lower') + and hasattr(node.inputs[1], 'value')): + # Node for input value. + val_input = node.inputs[0] + # Node for input parameter (e.g., shape, permute) + arg_input = node.inputs[1] + node.lower = node.forward(val_input.lower, arg_input.value) + node.upper = node.forward(val_input.upper, arg_input.value) + node.interval = (node.lower, node.upper) + return True nodes = [] - while not hasattr(node, 'lower') or not hasattr(node, 'upper'): - if type(node) not in self.ibp_intermediate: + while (getattr(node, 'lower', None) is None + or getattr(node, 'upper', None) is None): + if not node.ibp_intermediate: return False assert len(node.inputs) == 1, ( - 'Nodes in self.ibp_intermediate cannot have more than one input') + 'Nodes with ibp_intermediate=True cannot have more than one input') nodes.append(node) node = node.inputs[0] # FIXME: this cannot handle multiple inputs. nodes.reverse() for n in nodes: n.interval = self.IBP_general(n) + return True -def check_IBP_first_linear(self, node): +def check_IBP_first_linear(self: 'BoundedModule', node): """Here we avoid creating a big C matrix in the first linear layer. Disable this optimization when we have beta for intermediate layer bounds. Disable this optimization when we need the A matrix of the first nonlinear diff --git a/auto_LiRPA/jacobian.py b/auto_LiRPA/jacobian.py index aaf6ab5..b9d498a 100644 --- a/auto_LiRPA/jacobian.py +++ b/auto_LiRPA/jacobian.py @@ -2,15 +2,170 @@ import torch import numpy as np -from auto_LiRPA.bound_ops import BoundInput, BoundParams, BoundAdd -from auto_LiRPA.bound_ops import GradNorm, JVP -from auto_LiRPA.utils import get_spec_matrix, Flatten +from auto_LiRPA.bound_ops import JacobianOP, GradNorm # pylint: disable=unused-import +from auto_LiRPA.bound_ops import ( + BoundInput, BoundParams, BoundAdd, BoundRelu, BoundJacobianInit, JVP, + BoundJacobianOP) +from auto_LiRPA.utils import Flatten, logger, prod, get_spec_matrix from collections import deque +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule -def augment_gradient_graph(self, dummy_input, norm=None, vector=None): - """Augment the computational graph with gradient computation.""" +def _expand_jacobian(self): + self.jacobian_start_nodes = [] + for node in list(self.nodes()): + if isinstance(node, BoundJacobianOP): + self.jacobian_start_nodes.append(node.inputs[0]) + expand_jacobian_node(self, node) + if self.jacobian_start_nodes: + # Disable unstable options + self.bound_opts.update({ + 'sparse_intermediate_bounds': False, + 'sparse_conv_intermediate_bounds': False, + 'sparse_intermediate_bounds_with_ibp': False, + 'sparse_features_alpha': False, + 'sparse_spec_alpha': False, + }) + for node in self.nodes(): + if isinstance(node, BoundRelu): + node.use_sparse_spec_alpha = node.use_sparse_features_alpha = False + + +def expand_jacobian_node(self, jacobian_node): + """New API for converting a graph with Jacobian + + Based on the old API `augment_gradient_graph`. + """ + logger.info(f'Expanding Jacobian node {jacobian_node}') + + output_node = jacobian_node.inputs[0] + input_node = jacobian_node.inputs[1] + batch_size = output_node.output_shape[0] + output_dim = prod(output_node.output_shape[1:]) + + # Gradient values in `grad` may not be accurate. We do not consider gradient + # accumulation from multiple succeeding nodes. We only want the shapes but + # not the accurate values. + grad = {} + # Dummy values in grad_start + grad_start = torch.ones(batch_size, output_dim, output_dim, device=self.device) + grad[output_node.name] = grad_start + input_node_found = False + + # First BFS pass: traverse the graph, count degrees, and build gradient + # layers. + # Degrees of nodes. + degree = {} + # Original layer for gradient computation. + layer_grad = {} + # Input nodes in gradient computation in back propagation. + input_nodes = {} + # Dummy input values for gradient computation received. + grad_input = {} + # Extra nodes as arguments used for gradient computation. + # They must match the order in grad_input. + grad_extra_nodes = {} + + degree[output_node.name] = 0 + queue = deque([output_node]) + while len(queue) > 0: + node = queue.popleft() + grad_extra_nodes[node.name] = [] + input_nodes[node.name] = node.inputs + + if node == input_node: + input_node_found = True + layer_grad[node.name] = Flatten() + grad_input[node.name] = (grad[node.name],) + else: + ret = node.build_gradient_node(grad[node.name]) + node_grad, grad_input_, grad_extra_nodes_ = ret + layer_grad[node.name] = node_grad + grad_input[node.name] = grad_input_ + grad_extra_nodes[node.name] = grad_extra_nodes_ + + # Propagate gradients to the input nodes and update degrees. + grad_next = layer_grad[node.name](*grad_input[node.name]) + if isinstance(grad_next, torch.Tensor): + grad_next = [grad_next] + if not isinstance(node, BoundInput): + for i in range(len(grad_next)): + grad[input_nodes[node.name][i].name] = grad_next[i] + if not input_nodes[node.name][i].name in degree: + degree[input_nodes[node.name][i].name] = 0 + queue.append(input_nodes[node.name][i]) + degree[input_nodes[node.name][i].name] += 1 + + if not input_node_found: + raise RuntimeError('Input node not found') + + # Second BFS pass: build the backward computational graph + grad_node = {} + initial_name = f'/jacobian{output_node.name}{output_node.name}' + grad_node[output_node.name] = BoundJacobianInit(inputs=[output_node]) + grad_node[output_node.name].name = initial_name + self.add_nodes([grad_node[output_node.name]]) + queue = deque([output_node]) + while len(queue) > 0: + node = queue.popleft() + nodes_op, nodes_in, nodes_out, _ = self._convert_nodes( + layer_grad[node.name], grad_input[node.name]) + rename_dict = {} + assert isinstance(nodes_in[0], BoundInput) + rename_dict[nodes_in[0].name] = grad_node[node.name].name + for i in range(1, len(nodes_in)): + # Assume it's a parameter here + new_name = f'/jacobian{output_node.name}{node.name}/params{nodes_in[i].name}' + rename_dict[nodes_in[i].name] = new_name + for i in range(len(nodes_op)): + # intermediate nodes + if not nodes_op[i].name in rename_dict: + new_name = f'/jacobian{output_node.name}{node.name}/tmp{nodes_op[i].name}' + rename_dict[nodes_op[i].name] = new_name + if node == input_node: + assert len(nodes_out) == 1 + rename_dict[nodes_out[0].name] = f'/jacobian{output_node.name}' + else: + for i in range(len(nodes_out)): + assert not isinstance(node.inputs[i], BoundParams) + rename_dict[nodes_out[i].name] = f'/jacobian{output_node.name}{node.inputs[i].name}' + + self.rename_nodes(nodes_op, nodes_in, rename_dict) + # Replace input nodes + # grad_extra_nodes[node.name]: ReLU's input + input_nodes_replace = ( + [self._modules[nodes_in[0].name]] + grad_extra_nodes[node.name]) + for i in range(len(input_nodes_replace)): + for n in nodes_op: + for j in range(len(n.inputs)): + if n.inputs[j].name == nodes_in[i].name: + n.inputs[j] = input_nodes_replace[i] + self.add_nodes(nodes_op + nodes_in[len(input_nodes_replace):]) + + if node != input_node: + for i in range(len(nodes_out)): + if input_nodes[node.name][i].name in grad_node: + node_cur = grad_node[input_nodes[node.name][0].name] + node_add = BoundAdd( + attr=None, inputs=[node_cur, nodes_out[i]], + output_index=0, options={}) + node_add.name = f'{nodes_out[i].name}/add' + grad_node[input_nodes[node.name][0].name] = node_add + else: + grad_node[input_nodes[node.name][0].name] = nodes_out[i] + degree[input_nodes[node.name][i].name] -= 1 + if degree[input_nodes[node.name][i].name] == 0: + queue.append(input_nodes[node.name][i]) + else: + self.replace_node(jacobian_node, grad_node[node.name]) + + +def augment_gradient_graph(self: 'BoundedModule', dummy_input, norm=None, + vector=None): + """Augment the computational graph with gradient computation.""" device = dummy_input.device final_node = self.final_node() if final_node.forward_value is None: @@ -64,9 +219,6 @@ def augment_gradient_graph(self, dummy_input, norm=None, vector=None): if norm is None: layer_grad[node.name] = Flatten() else: - if norm != np.inf: - raise NotImplementedError( - 'Only inf norm is supported for now.') dual_norm = 1. / (1. - 1. / norm) if norm != 1 else np.inf layer_grad[node.name] = GradNorm(norm=dual_norm) grad_input[node.name] = (grad[node.name],) @@ -166,6 +318,9 @@ def augment_gradient_graph(self, dummy_input, norm=None, vector=None): 'sparse_features_alpha': False, 'sparse_spec_alpha': False, }) + for node in self.nodes(): + if isinstance(node, BoundRelu): + node.use_sparse_spec_alpha = node.use_sparse_features_alpha = False self.forward_final_name = self.final_name self.final_name = '/grad_norm' @@ -174,8 +329,8 @@ def augment_gradient_graph(self, dummy_input, norm=None, vector=None): return self -def compute_jacobian_bounds( - self, x, optimize=True, reduce=True, c_opt=None, labels=None): +def compute_jacobian_bounds(self: 'BoundedModule', x, optimize=True, + reduce=True, c_opt=None, labels=None): """Compute jacobian bounds on the pre-augmented graph. Args: @@ -193,13 +348,10 @@ def compute_jacobian_bounds( output dimensions, otherwise return the Jacobian bounds for all the output dimensions in a tensor. """ - assert 'jacobian' in self.bound_opts, ( 'Call augment_gradient_graph to augment the computational graph ' 'with the backward graph first') norm = self.bound_opts.get('jacobian', {}).get('norm', None) - assert norm is None or norm == np.inf, ( - 'Only Linf norm of Jacobian is supported for now.') num_classes = self[self.forward_final_name].output_shape[-1] @@ -227,7 +379,7 @@ def compute_jacobian_bounds( intermediate_bounds[node.name] = (node.lower, node.upper) lb, ub = self.compute_bounds( method='CROWN', x=(x,) + x_extra, bound_lower=norm is None, - intermediate_layer_bounds=intermediate_bounds) + interm_bounds=intermediate_bounds) if norm is not None: ret.append(ub.view(-1)) else: @@ -243,3 +395,37 @@ def compute_jacobian_bounds( lower = torch.concat(lower, dim=0) upper = torch.concat(upper, dim=0) return lower, upper + + +def compute_jacobian_bounds_new(self: 'BoundedModule', x, optimize=True, + optimize_output_node=None, + bound_lower=True, bound_upper=True): + """Compute jacobian bounds on the pre-augmented graph (new API).""" + + if isinstance(x, torch.Tensor): + x = (x,) + + if optimize: + if optimize_output_node is None: + if len(self.jacobian_start_nodes) == 1: + optimize_output_node = self.jacobian_start_nodes[0] + else: + raise NotImplementedError( + 'Multiple Jacobian nodes found.' + 'An output node for optimizable bounds (optimize_output_node) ' + 'must be specified explicitly') + self.compute_bounds( + method='CROWN-Optimized', + C=None, x=x, bound_upper=False, + final_node_name=optimize_output_node.name) + intermediate_bounds = {} + for node in self._modules.values(): + if hasattr(node, 'lower') and node.lower is not None: + intermediate_bounds[node.name] = (node.lower, node.upper) + else: + intermediate_bounds = None + lb, ub = self.compute_bounds( + method='CROWN', x=x, + bound_lower=bound_lower, bound_upper=bound_upper, + interm_bounds=intermediate_bounds) + return lb, ub diff --git a/auto_LiRPA/operators/__init__.py b/auto_LiRPA/operators/__init__.py index c2873b6..4ab2f51 100644 --- a/auto_LiRPA/operators/__init__.py +++ b/auto_LiRPA/operators/__init__.py @@ -5,7 +5,10 @@ from .activation_base import * from .activations import * from .nonlinear import * +from .relu import * +from .tanh import * from .bivariate import * +from .add_sub import * from .normalization import * from .shape import * from .reduce import * @@ -16,7 +19,10 @@ from .logical import * from .dropout import * from .dtype import * +from .trigonometric import * from .cut_ops import * from .gradient_bounds import * from .gradient_modules import * from .solver_utils import grb +from .resize import * +from .jacobian import * diff --git a/auto_LiRPA/operators/activation_base.py b/auto_LiRPA/operators/activation_base.py index 6d357c9..14a7a76 100644 --- a/auto_LiRPA/operators/activation_base.py +++ b/auto_LiRPA/operators/activation_base.py @@ -10,25 +10,28 @@ class BoundActivation(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.requires_input_bounds = [0] - self.relaxed = False self.use_default_ibp = True + self.splittable = False def _init_masks(self, x): self.mask_pos = x.lower >= 0 self.mask_neg = x.upper <= 0 self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg)) - def init_linear_relaxation(self, x, dim_opt=None): + def init_linear_relaxation(self, x): self._init_masks(x) self.lw = torch.zeros_like(x.lower) self.lb = self.lw.clone() self.uw = self.lw.clone() self.ub = self.lw.clone() - def add_linear_relaxation(self, mask, type, k, x0, y0): + def add_linear_relaxation(self, mask, type, k, x0, y0=None): + if y0 is None: + y0 = self.forward(x0) + if type == 'lower': w_out, b_out = self.lw, self.lb else: @@ -40,10 +43,8 @@ def add_linear_relaxation(self, mask, type, k, x0, y0): else: w_out.fill_(k) else: - if isinstance(k, Tensor): - w_out[..., mask] = k[..., mask].to(w_out) - else: - w_out[..., mask] = k + w_out[..., mask] = (k[..., mask].to(w_out) if isinstance(k, Tensor) + else k) if (not isinstance(x0, Tensor) and x0 == 0 and not isinstance(y0, Tensor) and y0 == 0): @@ -58,13 +59,11 @@ def add_linear_relaxation(self, mask, type, k, x0, y0): else: b_out[..., mask] = b[..., mask] - def bound_relax(self, x): + def bound_relax(self, x, init=False): return not_implemented_op(self, 'bound_relax') - def bound_backward(self, last_lA, last_uA, x): - if not self.relaxed: - self.init_linear_relaxation(x) - self.bound_relax(x) + def bound_backward(self, last_lA, last_uA, x, reduce_bias=True, **kwargs): + self.bound_relax(x, init=True) def _bound_oneside(last_A, sign=-1): if last_A is None: @@ -82,9 +81,11 @@ def _bound_oneside(last_A, sign=-1): b_pos = maybe_unfold_patches(b_pos, last_A) b_neg = maybe_unfold_patches(b_neg, last_A) if self.batch_dim == 0: - _A, _bias = multiply_by_A_signs(last_A, w_pos, w_neg, b_pos, b_neg) + _A, _bias = multiply_by_A_signs( + last_A, w_pos, w_neg, b_pos, b_neg, reduce_bias=reduce_bias) elif self.batch_dim == -1: # FIXME: why this is different from above? + assert reduce_bias mask = torch.gt(last_A, 0.).to(torch.float) _A = last_A * (mask * w_pos.unsqueeze(1) + (1 - mask) * w_neg.unsqueeze(1)) @@ -122,19 +123,19 @@ def bound_forward_b( return lb, ub def bound_forward(self, dim_in, x): - if not self.relaxed: - self.init_linear_relaxation(x) - self.bound_relax(x) + self.bound_relax(x, init=True) assert (x.lw is None) == (x.uw is None) dim = 1 if self.lw.ndim > 0 else 0 if x.lw is not None: - lw, uw = BoundActivation.bound_forward_w(self.lw, self.uw, x.lw, x.uw, dim) + lw, uw = BoundActivation.bound_forward_w( + self.lw, self.uw, x.lw, x.uw, dim) else: lw = uw = None - lb, ub = BoundActivation.bound_forward_b(self.lw, self.uw, self.lb, self.ub, x.lb, x.ub) + lb, ub = BoundActivation.bound_forward_b( + self.lw, self.uw, self.lb, self.ub, x.lb, x.ub) return LinearBound(lw, lb, uw, ub) @@ -142,10 +143,18 @@ def interval_propagate(self, *v): h_L, h_U = v[0][0], v[0][1] return self.forward(h_L), self.forward(h_U) + def get_split_mask(self, lower, upper, input_index): + """Return a mask to indicate if each neuron potentially needs a split. + + 0: Stable (linear) neuron; 1: unstable (nonlinear) neuron. + """ + return torch.ones_like(lower) + class BoundOptimizableActivation(BoundActivation): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) + self.optimizable = True # Stages: # * `init`: initializing parameters # * `opt`: optimizing parameters @@ -155,12 +164,10 @@ def __init__(self, attr, inputs, output_index, options): self.alpha = OrderedDict() # Save patch sizes during bound_backward() for each output_node. self.patch_size = {} - # Location of batch dimension in self.alpha. Must be set by children. - self.alpha_batch_dim = None - # A torch.bool mask of shape Tensor([batch_size]) that conditions the sample of alpha and beta to update + # A torch.bool mask of shape Tensor([batch_size]) that conditions the + # sample of alpha and beta to update # If set to None, update all samples # If not None, select those corresponding to 1 to update - self.alpha_beta_update_mask = None def opt_init(self): """Enter the stage for initializing bound optimization. Optimized bounds @@ -184,57 +191,82 @@ def opt_end(self): """ End optimizing bounds """ self.opt_stage = None + def clip_alpha(self): + pass + def init_opt_parameters(self, start_nodes): """ start_nodes: a list of starting nodes [(node, size)] during CROWN backward bound propagation""" + self.alpha = OrderedDict() + for start_node in start_nodes: + ns, size_s = start_node[:2] + # TODO do not give torch.Size + if isinstance(size_s, (torch.Size, list, tuple)): + size_s = prod(size_s) + self.alpha[ns] = self._init_opt_parameters_impl(size_s, name_start=ns) + + def _init_opt_parameters_impl(self, size_spec, name_start=None): + """Implementation of init_opt_parameters for each start_node.""" raise NotImplementedError - def clip_alpha_(self): - pass - def init_linear_relaxation(self, x, dim_opt=None): self._init_masks(x) # The first dimension of size 2 is used for lA and uA respectively, # when computing intermediate bounds. if self.opt_stage in ['opt', 'reuse'] and dim_opt is not None: - # For optimized bounds, we have independent lw for each output dimension for bound optimization. + # For optimized bounds, we have independent lw for each output + # dimension for bound optimization. # If the output layer is a fully connected layer, len(dim_opt) = 1. - # If the output layer is a conv layer, len(dim_opt) = 3 but we only use the out_c dimension to create slopes/bias. + # If the output layer is a conv layer, len(dim_opt) = 3 but we only + # use the out_c dimension to create slopes/bias. # Variables are shared among out_h, out_w dimensions so far. - dim = dim_opt if isinstance(dim_opt, int) else dim_opt[0] + if isinstance(dim_opt, int): + dim = dim_opt + elif isinstance(dim_opt, torch.Size): + dim = prod(dim_opt) + else: + dim = dim_opt[0] self.lw = torch.zeros(2, dim, *x.lower.shape).to(x.lower) else: - # Without optimized bounds, the lw, lb (slope, biase) etc only depend on intermediate layer bounds, + # Without optimized bounds, the lw, lb (slope, biase) etc only + # depend on intermediate layer bounds, # and are shared among different output dimensions. self.lw = torch.zeros_like(x.lower) self.lb = self.lw.clone() self.uw = self.lw.clone() self.ub = self.lw.clone() - def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): - self._start = start_node.name + def bound_relax(self, x, init=False, dim_opt=None): + return not_implemented_op(self, 'bound_relax') + def bound_backward(self, last_lA, last_uA, x, start_node=None, + start_shape=None, reduce_bias=True, **kwargs): + self._start = start_node.name if self.opt_stage not in ['opt', 'reuse']: last_A = last_lA if last_lA is not None else last_uA # Returned [(lA, uA)], lbias, ubias - As, lbias, ubias = super().bound_backward(last_lA, last_uA, x) + As, lbias, ubias = super().bound_backward( + last_lA, last_uA, x, reduce_bias=reduce_bias) if isinstance(last_A, Patches): A_prod = As[0][1].patches if As[0][0] is None else As[0][1].patches # FIXME: Unify this function with BoundReLU - # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters. + # Save the patch size, which will be used in init_slope() to + # determine the number of optimizable parameters. if start_node is not None: if last_A.unstable_idx is not None: - # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w). - self.patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)] + # Sparse patches, we need to construct the full patch size: + # (out_c, batch, out_h, out_w, c, h, w). + self.patch_size[start_node.name] = [ + last_A.output_shape[1], A_prod.size(1), + last_A.output_shape[2], last_A.output_shape[3], + A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)] else: # Regular patches. self.patch_size[start_node.name] = A_prod.size() return As, lbias, ubias assert self.batch_dim == 0 - if not self.relaxed: - self.init_linear_relaxation(x, dim_opt=start_shape) - self.bound_relax(x) + self.bound_relax(x, init=True, dim_opt=start_shape) def _bound_oneside(last_A, sign=-1): if last_A is None: @@ -247,7 +279,18 @@ def _bound_oneside(last_A, sign=-1): w_neg = maybe_unfold_patches(w_neg, last_A) b_pos = maybe_unfold_patches(b_pos, last_A) b_neg = maybe_unfold_patches(b_neg, last_A) - A_prod, _bias = multiply_by_A_signs(last_A, w_pos, w_neg, b_pos, b_neg) + unstable_idx = kwargs.get('unstable_idx', None) + if unstable_idx is not None: + assert isinstance(unstable_idx, Tensor) and unstable_idx.ndim == 1 + # Shape is (spec, batch, neurons). + # FIXME: Sigmoid and other activation functions should also support + # sparse-spec alpha, so alpha will be created with a smaller shape. + w_pos = self.non_deter_index_select(w_pos, index=unstable_idx, dim=0) + w_neg = self.non_deter_index_select(w_neg, index=unstable_idx, dim=0) + b_pos = self.non_deter_index_select(b_pos, index=unstable_idx, dim=0) + b_neg = self.non_deter_index_select(b_neg, index=unstable_idx, dim=0) + A_prod, _bias = multiply_by_A_signs( + last_A, w_pos, w_neg, b_pos, b_neg, reduce_bias) return A_prod, _bias lA, lbias = _bound_oneside(last_lA, sign=-1) @@ -261,13 +304,7 @@ def _no_bound_parameters(self): ' at least once.') def dump_optimized_params(self): - raise NotImplementedError - - def restore_optimized_params(self): - raise NotImplementedError - - def set_alpha_beta_update_mask(self, mask): - self.alpha_beta_update_mask = mask + return self.alpha - def clean_alpha_beta_update_mask(self): - self.alpha_beta_update_mask = None + def restore_optimized_params(self, alpha): + self.alpha = alpha diff --git a/auto_LiRPA/operators/activations.py b/auto_LiRPA/operators/activations.py index 86bb0f3..27d9fe4 100644 --- a/auto_LiRPA/operators/activations.py +++ b/auto_LiRPA/operators/activations.py @@ -1,957 +1,16 @@ -""" Activation operators or other unary nonlinear operators""" -from typing import Optional, Tuple +""" Activation operators or other unary nonlinear operators, not including +those placed in separate files.""" import torch -from torch import Tensor -from collections import OrderedDict from .base import * -from .clampmult import multiply_by_A_signs from .activation_base import BoundActivation, BoundOptimizableActivation -from .gradient_modules import ReLUGrad -from .solver_utils import grb -from ..utils import unravel_index, logger, prod torch._C._jit_set_profiling_executor(False) torch._C._jit_set_profiling_mode(False) -class BoundRelu(BoundOptimizableActivation): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.options = options - self.relu_options = options.get('relu', 'adaptive') # FIXME: use better names. - self.use_sparse_spec_alpha = options.get('sparse_spec_alpha', False) - self.use_sparse_features_alpha = options.get('sparse_features_alpha', False) - self.beta = self.beta_mask = self.masked_beta = self.sparse_beta = None - self.split_beta_used = False - self.history_beta_used = False - self.flattened_nodes = None - # Save patches size for each output node. - self.patch_size = {} - self.cut_used = False - self.cut_module = None - # Alpha dimension is (2, output_shape, batch, *shape) for ReLU. - self.alpha_batch_dim = 2 - - def init_opt_parameters(self, start_nodes): - ref = self.inputs[0].lower # a reference variable for getting the shape - batch_size = ref.size(0) - self.alpha = OrderedDict() - self.alpha_lookup_idx = OrderedDict() # For alpha with sparse spec dimention. - self.alpha_indices = None # indices of non-zero alphas. - verbosity = self.options.get('verbosity', 0) - - # Alpha can be sparse in both spec dimension, and the C*H*W dimension. - # We first deal with the sparse-feature alpha, which is sparse in the - # C*H*W dimesnion of this layer. - minimum_sparsity = self.options.get('minimum_sparsity', 0.9) - if (hasattr(self.inputs[0], 'lower') and hasattr(self.inputs[0], 'upper') - and self.use_sparse_features_alpha): - # Pre-activation bounds available, we will store the alpha for unstable neurons only. - # Since each element in a batch can have different unstable neurons, - # for simplicity we find a super-set using any(dim=0). - # This can be non-ideal if the x in a batch are very different. - self.alpha_indices = torch.logical_and( - self.inputs[0].lower < 0, self.inputs[0].upper > 0).any(dim=0).nonzero(as_tuple=True) - total_neuron_size = self.inputs[0].lower.numel() // batch_size - if self.alpha_indices[0].size(0) <= minimum_sparsity * total_neuron_size: - # Shape is the number of unstable neurons in this layer. - alpha_shape = [self.alpha_indices[0].size(0)] - # Skip the batch, spec dimension, and find the lower slopes for all unstable neurons. - if len(self.alpha_indices) == 1: - # This layer is after a linear layer. - alpha_init = self.lower_d[:, :, self.alpha_indices[0]] - elif len(self.alpha_indices) == 3: - # This layer is after a conv layer. - alpha_init = self.lower_d[ - :, :, self.alpha_indices[0], self.alpha_indices[1], - self.alpha_indices[2]] - else: - raise ValueError - if verbosity > 0: - print(f'layer {self.name} using sparse-features alpha with shape {alpha_shape}; unstable size {self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({ref.shape})') - else: - alpha_shape = self.shape # Full alpha. - alpha_init = self.lower_d - if verbosity > 0: - print(f'layer {self.name} using full alpha with shape {alpha_shape}; unstable size {self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({ref.shape})') - self.alpha_indices = None # Use full alpha. - else: - alpha_shape = self.shape # Full alpha. - alpha_init = self.lower_d - # Now we start to create alphas for all start nodes. - # When sparse-spec feature is enabled, alpha is created for only - # unstable neurons in start node. - for ns, output_shape, unstable_idx in start_nodes: - if isinstance(output_shape, (list, tuple)): - if len(output_shape) > 1: - size_s = prod(output_shape) # Conv layers. - else: - size_s = output_shape[0] - else: - size_s = output_shape - # unstable_idx may be a tensor (dense layer or conv layer - # with shared alpha), or tuple of 3-d tensors (conv layer with - # non-sharing alpha). - sparsity = float('inf') if unstable_idx is None else unstable_idx.size(0) if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].size(0) - if sparsity <= minimum_sparsity * size_s and self.use_sparse_spec_alpha: - if verbosity > 0: - print(f'layer {self.name} start_node {ns} using sparse-spec alpha with unstable size {sparsity} total_size {size_s} output_shape {output_shape}') - # For fully connected layer, or conv layer with shared alpha per channel. - # shape is (2, sparse_spec, batch, this_layer_shape) - # We create sparse specification dimension, where the spec dimension of alpha only includes slopes for unstable neurons in start_node. - self.alpha[ns] = torch.empty([2, sparsity + 1, batch_size, *alpha_shape], - dtype=torch.float, device=ref.device, requires_grad=True) - self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, sparse_spec) dimensions. - # unstable_idx is a list of used neurons (or channels for BoundConv) for the start_node. - assert unstable_idx.ndim == 1 if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].ndim == 1 - # We only need to the alpha for the unstable neurons in start_node. - indices = torch.arange(1, sparsity + 1, device=alpha_init.device, dtype=torch.long) - if isinstance(output_shape, int) or len(output_shape) == 1: - # Fully connected layers, or conv layer in patches mode with partially shared alpha (pixels in the same channel use the same alpha). - self.alpha_lookup_idx[ns] = torch.zeros(size_s, dtype=torch.long, device=alpha_init.device) - # This lookup table maps the unstable_idx to the actual alpha location in self.alpha[ns]. - # Note that self.alpha[ns][:,0] is reserved for any unstable neurons that are not found in the lookup table. This usually should not - # happen, unless reference bounds are not properly set. - self.alpha_lookup_idx[ns].data[unstable_idx] = indices - else: - # conv layer in matrix mode, or in patches mode but with non-shared alpha. The lookup table is 3-d. - assert len(output_shape) == 3 - self.alpha_lookup_idx[ns] = torch.zeros(output_shape, dtype=torch.long, device=alpha_init.device) - if isinstance(unstable_idx, torch.Tensor): - # Convert the unstable index from flattend 1-d to 3-d. (matrix mode). - unstable_idx_3d = unravel_index(unstable_idx, output_shape) - else: - # Patches mode with non-shared alpha, unstable_idx is already 3d. - unstable_idx_3d = unstable_idx - # Build look-up table. - self.alpha_lookup_idx[ns].data[unstable_idx_3d[0], unstable_idx_3d[1], unstable_idx_3d[2]] = indices - else: - if verbosity > 0: - print(f'layer {self.name} start_node {ns} using full alpha with unstable size {sparsity if unstable_idx is not None else None} total_size {size_s} output_shape {output_shape}') - # alpha shape is (2, spec, batch, this_layer_shape). "this_layer_shape" may still be sparse. - self.alpha[ns] = torch.empty([2, size_s, batch_size, *alpha_shape], - dtype=torch.float, device=ref.device, requires_grad=True) - self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, spec) dimensions - # alpha_lookup_idx can be used for checking if sparse alpha is used or not. - self.alpha_lookup_idx[ns] = None - - def clip_alpha_(self): - for v in self.alpha.values(): - v.data = torch.clamp(v.data, 0., 1.) - - def forward(self, x): - self.shape = x.shape[1:] - if self.flattened_nodes is None: - self.flattened_nodes = x[0].reshape(-1).shape[0] - return F.relu(x) - - def _forward_relaxation(self, x): - self._init_masks(x) - self.mask_pos = self.mask_pos.to(x.lower) - self.mask_both = self.mask_both.to(x.lower) - - upper_k, upper_b = self._relu_upper_bound(x.lower, x.upper) - self.uw = self.mask_pos + self.mask_both * upper_k - self.ub = self.mask_both * upper_b - - if self.opt_stage in ['opt', 'reuse']: - # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape]. - # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape] - # and we do not need its first two dimensions. - lower_k = alpha = self.alpha['_forward'][0, 0] - elif self.relu_options == "same-slope": - lower_k = upper_k - elif self.relu_options == "zero-lb": - lower_k = torch.zeros_like(upper_k) - elif self.relu_options == "one-lb": - lower_k = torch.ones_like(upper_k) - else: - # adaptive - lower_k = torch.gt(torch.abs(x.upper), torch.abs(x.lower)).to(torch.float) - # NOTE #FIXME Saved for initialization bounds for optimization. - # In the backward mode, same-slope bounds are used. - # But here it is using adaptive bounds which seem to be better - # for nn4sys benchmark with loose input bounds. Need confirmation - # for other cases. - self.lower_d = lower_k.detach() # saved for initializing optimized bounds - - self.lw = self.mask_both * lower_k + self.mask_pos - - def bound_dynamic_forward(self, x, max_dim=None, offset=0): - self._init_masks(x) - self.mask_pos = self.mask_pos.to(x.lower) - self.mask_both = self.mask_both.to(x.lower) - - upper_k, upper_b = self._relu_upper_bound(x.lower, x.upper) - w_new = (self.mask_pos.unsqueeze(1) * x.lw - + self.mask_both.unsqueeze(1) * upper_k.unsqueeze(1) * x.lw) - upper_b = self.mask_both * upper_b / 2 - b_new = (self.mask_pos * x.lb - + self.mask_both * upper_k * x.lb + upper_b) - - # Create new variables for unstable ReLU - batch_size = w_new.shape[0] - device = w_new.device - unstable = self.mask_both.view(batch_size, -1) - tot_unstable = int(unstable.sum(dim=-1).max()) - tot_dim = x.tot_dim + tot_unstable - # logger.debug(f'Unstable: {tot_unstable}') - - if offset + w_new.shape[1] < x.tot_dim: - return LinearBound( - w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=tot_dim) - - index = torch.cumsum(unstable, dim=-1).to(torch.int64) - index = (index - (offset + w_new.shape[1] - x.tot_dim)).clamp(min=0) - num_new_dim = int(index.max()) - num_new_dim_actual = min(num_new_dim, max_dim - w_new.shape[1]) - index = index.clamp(max=num_new_dim_actual+1) - w_unstable = torch.zeros(batch_size, num_new_dim_actual + 2, unstable.size(-1), device=device) - x_L_unstable = -torch.ones(batch_size, num_new_dim_actual, device=device) - x_U_unstable = torch.ones(batch_size, num_new_dim_actual, device=device) - w_unstable.scatter_(dim=1, index=index.unsqueeze(1), src=upper_b.view(batch_size, 1, -1), reduce='add') - w_unstable = w_unstable[:, 1:-1].view(batch_size, num_new_dim_actual, *w_new.shape[2:]) - - w_new = torch.cat([w_new, w_unstable], dim=1) - x_L_new = torch.cat([x.x_L, x_L_unstable], dim=-1) - x_U_new = torch.cat([x.x_U, x_U_unstable], dim=-1) - - return LinearBound( - w_new, b_new, w_new, b_new, x_L=x_L_new, x_U=x_U_new, tot_dim=tot_dim) - - - def bound_forward(self, dim_in, x): - self._forward_relaxation(x) - - lb = self.lw * x.lb - ub = self.uw * x.ub + self.ub - - if x.lw is not None: - lw = self.lw.unsqueeze(1) * x.lw - else: - lw = None - if x.uw is not None: - uw = self.uw.unsqueeze(1) * x.uw - else: - uw = None - - if not lw.requires_grad: - del self.mask_both, self.mask_pos - del self.lw, self.uw, self.ub - - return LinearBound(lw, lb, uw, ub) - - @staticmethod - @torch.jit.script - def _relu_upper_bound(lb, ub): - """Upper bound slope and intercept according to CROWN relaxation.""" - # TODO: pre-comple all JIT functions before run. - lb_r = lb.clamp(max=0) - ub_r = ub.clamp(min=0) - ub_r = torch.max(ub_r, lb_r + 1e-8) - upper_d = ub_r / (ub_r - lb_r) - upper_b = - lb_r * upper_d - return upper_d, upper_b - - @staticmethod - def _relu_mask_alpha(lower, upper, lb_lower_d : Optional[Tensor], ub_lower_d : Optional[Tensor]) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]: - lower_mask = (lower >= 0).requires_grad_(False).to(lower.dtype) - upper_mask = (upper <= 0).requires_grad_(False) - zero_coeffs = upper_mask.all() - no_mask = (1. - lower_mask) * (1. - upper_mask.to(upper.dtype)) - if lb_lower_d is not None: - lb_lower_d = torch.clamp(lb_lower_d, min=0., max=1.) * no_mask + lower_mask - if ub_lower_d is not None: - ub_lower_d = torch.clamp(ub_lower_d, min=0., max=1.) * no_mask + lower_mask - return lb_lower_d, ub_lower_d, zero_coeffs - - def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx): - if x is not None: - lower = x.lower - upper = x.upper - else: - lower = self.lower - upper = self.upper - - # Upper bound slope and intercept according to CROWN relaxation. - upper_d, upper_b = self._relu_upper_bound(lower, upper) - - flag_expand = False - ub_lower_d = lb_lower_d = None - lower_b = None # ReLU does not have lower bound intercept (=0). - alpha_lookup_idx = None # For sparse-spec alpha. - if self.opt_stage in ['opt', 'reuse']: - # Alpha-CROWN. - lower_d = None - # Each alpha has shape (2, output_shape, batch_size, *relu_node_shape]. - # If slope is shared, output_shape will be 1. - # The *relu_node_shape might be sparse (sparse-feature alpha), where the non-zero values are indicated by self.alpha_indices. - # The out_shape might be sparse (sparse-spec alpha), where the non-zero values are indexed by self.alpha_lookup_idx. - if unstable_idx is not None: - # print(f'relu layer {self.name}, start_node {start_node}, unstable_idx {type(unstable_idx)} alpha idx {self.alpha_lookup_idx[start_node.name].size()}') - alpha_lookup_idx = self.alpha_lookup_idx[start_node.name] - if isinstance(unstable_idx, tuple): - # Start node is a conv node. - selected_alpha = self.alpha[start_node.name] - if isinstance(last_lA, Tensor) or isinstance(last_uA, Tensor): - # Start node is a conv node but we received tensors as A matrices. - # Patches mode converted to matrix, or matrix mode used. Need to select accross the spec dimension. - # For this node, since it is in matrix mode, the spec dimension is out_c * out_h * out_w - # Shape is [2, spec, batch, *this_layer_shape] - if alpha_lookup_idx is None: - # Reshape the spec dimension to c*h*w so we can select used alphas based on unstable index. - # Shape becomes [2, out_c, out_h, out_w, batch, *this_layer_shape] - selected_alpha = selected_alpha.view(selected_alpha.size(0), *start_node.output_shape[1:], *selected_alpha.shape[2:]) - selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] - else: - assert alpha_lookup_idx.ndim == 3 - # We only stored some alphas, and A is also sparse, so the unstable_idx must be first translated to real indices. - # alpha shape is (2, sparse_spec_shape, batch_size, *relu_node_shape) where relu_node_shape can also be sparse. - # We use sparse-spec alphas. Need to convert these unstable_idx[0], unstable_idx[1], unstable_idx[0] using lookup table. - _unstable_idx = alpha_lookup_idx[unstable_idx[0], unstable_idx[1], unstable_idx[2]] - selected_alpha = self.non_deter_index_select(selected_alpha, index=_unstable_idx, dim=1) - else: - # Patches mode. Alpha must be selected after unfolding, so cannot be done here. - # Selection is deferred to maybe_unfold() using alpha_lookup_idx. - # For partially shared alpha, its shape is (2, out_c, batch_size, *relu_node_shape). - # For full alpha, its shape is (2, out_c*out_h*out_w, batch_size, *relu_node_shape). - # Both the spec dimension and relu_node_shape dimensions can be sparse. - pass - elif unstable_idx.ndim == 1: - # Start node is a FC node. - # Only unstable neurons of the start_node neurons are used. - assert alpha_lookup_idx is None or alpha_lookup_idx.ndim == 1 - _unstable_idx = alpha_lookup_idx[unstable_idx] if alpha_lookup_idx is not None else unstable_idx - selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=_unstable_idx, dim=1) - elif unstable_idx.ndim == 2: - assert alpha_lookup_idx is None, "sparse spec alpha has not been implemented yet." - # Each element in the batch selects different neurons. - selected_alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) - else: - raise ValueError - else: - # Spec dimension is dense. Alpha must not be created sparsely. - assert self.alpha_lookup_idx[start_node.name] is None - selected_alpha = self.alpha[start_node.name] - # The first dimension is lower/upper intermediate bound. - if last_lA is not None: - lb_lower_d = selected_alpha[0] - if last_uA is not None: - ub_lower_d = selected_alpha[1] - - if self.alpha_indices is not None: - # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only. - # Recover to full alpha first. - def reconstruct_full_alpha(sparse_alpha, full_alpha_shape, alpha_indices): - full_alpha = torch.zeros(full_alpha_shape, dtype=sparse_alpha.dtype, device=sparse_alpha.device) - if len(alpha_indices) == 1: - # Relu after a dense layer. - full_alpha[:, :, alpha_indices[0]] = sparse_alpha - elif len(alpha_indices) == 3: - # Relu after a conv layer. - full_alpha[:, :, alpha_indices[0], alpha_indices[1], alpha_indices[2]] = sparse_alpha - else: - raise ValueError - return full_alpha - sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape - full_alpha_shape = sparse_alpha_shape[:-1] + self.shape - if lb_lower_d is not None: - lb_lower_d = reconstruct_full_alpha(lb_lower_d, full_alpha_shape, self.alpha_indices) - if ub_lower_d is not None: - ub_lower_d = reconstruct_full_alpha(ub_lower_d, full_alpha_shape, self.alpha_indices) - - # condition only on the masked part - if self.alpha_beta_update_mask is not None: - if lb_lower_d is not None: - lb_lower_d_new = lb_lower_d[:, self.alpha_beta_update_mask] - else: - lb_lower_d_new = None - if ub_lower_d is not None: - ub_lower_d_new = ub_lower_d[:, self.alpha_beta_update_mask] - else: - ub_lower_d_new = None - lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d_new, ub_lower_d_new) - else: - lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d, ub_lower_d) - self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = zero_coeffs - flag_expand = True # we already have the spec dimension. - elif self.relu_options == "same-slope": - # the same slope for upper and lower - lower_d = upper_d - elif self.relu_options == "zero-lb": - # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN - lower_d = (upper_d >= 1.0).to(upper_d.dtype) - elif self.relu_options == "one-lb": - # Always use slope 1 as lower bound - lower_d = (upper_d > 0.0).to(upper_d.dtype) - elif self.relu_options == "reversed-adaptive": - lower_d = (upper_d < 0.5).to(upper_d.dtype) - else: - # adaptive - lower_d = (upper_d > 0.5).to(upper_d.dtype) - - # Upper bound always needs an extra specification dimension, since they only depend on lb and ub. - upper_d = upper_d.unsqueeze(0) - upper_b = upper_b.unsqueeze(0) - if not flag_expand: - if self.opt_stage in ['opt', 'reuse']: - # We have different slopes for lower and upper bounds propagation. - lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None - ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None - else: - lower_d = lower_d.unsqueeze(0) - return upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, alpha_lookup_idx - - def bound_backward(self, last_lA, last_uA, x=None, start_node=None, beta_for_intermediate_layers=False, unstable_idx=None): - # Get element-wise CROWN linear relaxations. - upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, alpha_lookup_idx = \ - self._backward_relaxation(last_lA, last_uA, x, start_node, unstable_idx) - # save for calculate babsr score - self.d = upper_d - self.lA = last_lA - # Save for initialization bounds. - self.lower_d = lower_d - - # Choose upper or lower bounds based on the sign of last_A - def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): - if last_A is None: - return None, 0 - # Obtain the new linear relaxation coefficients based on the signs in last_A. - _A, _bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg) - if isinstance(last_A, Patches): - # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters. - A_prod = _A.patches - if start_node is not None: - if last_A.unstable_idx is not None: - # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w). - self.patch_size[start_node.name] = [last_A.output_shape[1], A_prod.size(1), last_A.output_shape[2], last_A.output_shape[3], A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)] - else: - # Regular patches. - self.patch_size[start_node.name] = A_prod.size() - return _A, _bias - - ######## A problem with patches mode for cut constraint start ########## - # There are cases that the node that is in the constraint but not selected by the patches for the output node - # trick: only count the small patches that have all the split node coeffs[ci].sum() equal to coeffs_unfolded[ci][out_h, out_w, -1].sum() - # we should force these beta to be 0 to disable the effect of these constraints - A = last_lA if last_lA is not None else last_uA - current_layer_shape = x.lower.size()[1:] - if self.cut_used and type(A) is Patches: - self.cut_module.patch_trick(start_node, self.name, A, current_layer_shape) - ######## A problem with patches mode for cut constraint end ########## - - if self.cut_used: - # propagate postrelu node in cut constraints - last_lA, last_uA = self.cut_module.relu_cut( - start_node, self.name, last_lA, last_uA, current_layer_shape, unstable_idx, - batch_mask=self.alpha_beta_update_mask) - - # In patches mode we might need an unfold. - # lower_d, upper_d, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None - upper_d = maybe_unfold_patches(upper_d, last_lA if last_lA is not None else last_uA) - lower_d = maybe_unfold_patches(lower_d, last_lA if last_lA is not None else last_uA) - upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA) - lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA) # for ReLU it is always None; keeping it here for completeness. - # ub_lower_d and lb_lower_d might have sparse spec dimension, so they may need alpha_lookup_idx to convert to actual spec dim. - ub_lower_d = maybe_unfold_patches(ub_lower_d, last_uA, alpha_lookup_idx=alpha_lookup_idx) - # optimizable slope lb_lower_d: spec (only channels in spec layer), batch, current_c, current_w, current_h - # patches mode lb_lower_d after unfold: unstable, batch, in_C, H, W - lb_lower_d = maybe_unfold_patches(lb_lower_d, last_lA, alpha_lookup_idx=alpha_lookup_idx) - - if self.cut_used: - I = (x.lower < 0) * (x.upper > 0) - # propagate integer var of relu neuron (arelu) in cut constraints through relu layer - lA, uA, lbias, ubias = self.cut_module.arelu_cut( - start_node, self.name, last_lA, last_uA, lower_d, upper_d, - lower_b, upper_b, lb_lower_d, ub_lower_d, I, x, self.patch_size, - current_layer_shape, unstable_idx, - batch_mask=self.alpha_beta_update_mask) - else: - uA, ubias = _bound_oneside( - last_uA, upper_d, ub_lower_d if lower_d is None else lower_d, - upper_b, lower_b) - lA, lbias = _bound_oneside( - last_lA, lb_lower_d if lower_d is None else lower_d, upper_d, - lower_b, upper_b) - - # Regular Beta CROWN with single neuron split - def _beta_crown_single_neuron_splits(A, uA, lA, unstable_idx): - if type(A) is Patches: - if self.options.get('enable_opt_interm_bounds', False): - # expand sparse_beta to full beta - beta_values = (self.sparse_beta[start_node.name] * self.sparse_beta_sign[start_node.name]) - beta_indices = self.sparse_beta_loc[start_node.name] - self.masked_beta = torch.zeros(2, *self.shape).reshape(2, -1).to(A.patches.dtype) - self.non_deter_scatter_add(self.masked_beta, dim=1, index=beta_indices, src=beta_values.to(self.masked_beta.dtype)) - self.masked_beta = self.masked_beta.reshape(2, *self.shape) - else: - if self.beta is None: - # Beta not used. - return lA, uA - # For patches mode, masked_beta will be used; sparse beta is not supported. - self.masked_beta = (self.beta[0] * self.beta_mask).requires_grad_() - # unfold the beta as patches, size (batch, out_h, out_w, in_c, H, W) - A_patches = A.patches - masked_beta_unfolded = inplace_unfold(self.masked_beta, kernel_size=A_patches.shape[-2:], padding=A.padding, stride=A.stride, inserted_zeros=A.inserted_zeros, output_padding=A.output_padding) - if A.unstable_idx is not None: - masked_beta_unfolded = masked_beta_unfolded.permute(1, 2, 0, 3, 4, 5) - # After selection, the shape is (unstable_size, batch, in_c, H, W). - masked_beta_unfolded = masked_beta_unfolded[A.unstable_idx[1], A.unstable_idx[2]] - else: - # Add the spec (out_c) dimension. - masked_beta_unfolded = masked_beta_unfolded.unsqueeze(0) - if self.alpha_beta_update_mask is not None: - masked_beta_unfolded = masked_beta_unfolded[self.alpha_beta_update_mask] - if uA is not None: - uA = uA.create_similar(uA.patches + masked_beta_unfolded) - if lA is not None: - lA = lA.create_similar(lA.patches - masked_beta_unfolded) - elif type(A) is Tensor: - if self.options.get('enable_opt_interm_bounds', False): - # For matrix mode, beta is sparse. - beta_values = (self.sparse_beta[start_node.name] * self.sparse_beta_sign[start_node.name]).expand(lA.size(0), -1, -1) - # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension. - beta_indices = self.sparse_beta_loc[start_node.name].unsqueeze(0).expand(lA.size(0), -1, -1) - else: - # For matrix mode, beta is sparse. - beta_values = (self.sparse_beta * self.sparse_beta_sign).expand(lA.size(0), -1, -1) - # self.single_beta_loc has shape [batch, max_single_split]. Need to expand at the specs dimension. - beta_indices = self.sparse_beta_loc.unsqueeze(0).expand(lA.size(0), -1, -1) - # For conv layer, the last dimension is flattened in indices. - prev_size = A.size() - if self.alpha_beta_update_mask is not None: - beta_indices = beta_indices[:, self.alpha_beta_update_mask] - beta_values = beta_values[:, self.alpha_beta_update_mask] - if uA is not None: - uA = self.non_deter_scatter_add(uA.view(uA.size(0), uA.size(1), -1), dim=2, index=beta_indices, src=beta_values.to(uA.dtype)) - uA = uA.view(prev_size) - if lA is not None: - lA = self.non_deter_scatter_add(lA.view(lA.size(0), lA.size(1), -1), dim=2, index=beta_indices, src=beta_values.neg().to(lA.dtype)) - lA = lA.view(prev_size) - else: - raise RuntimeError(f"Unknown type {type(A)} for A") - return lA, uA - - if self.cut_used: - # propagate prerelu node in cut constraints - lA, uA = self.cut_module.pre_cut(start_node, self.name, lA, uA, current_layer_shape, unstable_idx, - batch_mask=self.alpha_beta_update_mask) - self.masked_beta_lower = self.masked_beta_upper = None - if self.options.get('optimize_bound_args', {}).get('enable_beta_crown', False) and self.sparse_beta is not None: - if self.options.get('optimize_bound_args', {}).get('single_node_split', False): - # Beta-CROWN: each split constraint only has single neuron (e.g., second ReLU neuron > 0). - A = lA if lA is not None else uA - lA, uA = _beta_crown_single_neuron_splits(A, uA, lA, unstable_idx) - # The code block below is for debugging and will be removed (until the end of this function). - # elif False and not self.options.get('optimize_bound_args', {}).get('single_node_split', True): - # # Improved Beta-CROWN: (1) general split constraints: each split constraint have multiple neuron - # # (e.g., second ReLU neuron > 0); (2) intermediate Relu bounds refinement with the general split constraints. - # A = uA if uA is not None else lA - # lA, uA = _beta_crown_multi_neuron_splits(x, A, uA, lA, unstable_idx, start_node) - # print(lA.sum(), uA.sum()) - # exit() - - return [(lA, uA)], lbias, ubias - - def interval_propagate(self, *v): - h_L, h_U = v[0][0], v[0][1] - return F.relu(h_L), F.relu(h_U) - - def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): - # e.g., last layer input gurobi vars (8,16,16) - gvars_array = np.array(v[0]) - this_layer_shape = gvars_array.shape - assert gvars_array.shape == self.output_shape[1:] - - pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1) - pre_ubs = self.inputs[0].upper.cpu().detach().numpy().reshape(-1) - - new_layer_gurobi_vars = [] - relu_integer_vars = [] - new_relu_layer_constrs = [] - # predefined zero variable shared in the whole solver model - zero_var = model.getVarByName("zero") - - for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)): - pre_ub = pre_ubs[neuron_idx] - pre_lb = pre_lbs[neuron_idx] - - if pre_lb >= 0: - # ReLU is always passing - var = pre_var - elif pre_ub <= 0: - var = zero_var - else: - ub = pre_ub - - var = model.addVar(ub=ub, lb=pre_lb, - obj=0, - vtype=grb.GRB.CONTINUOUS, - name=f'ReLU{self.name}_{neuron_idx}') - - if model_type == "mip" or model_type == "lp_integer": - # binary indicator - if model_type == "mip": - a = model.addVar(vtype=grb.GRB.BINARY, name=f'aReLU{self.name}_{neuron_idx}') - elif model_type == "lp_integer": - a = model.addVar(ub=1, lb=0, vtype=grb.GRB.CONTINUOUS, name=f'aReLU{self.name}_{neuron_idx}') - relu_integer_vars.append(a) - - new_relu_layer_constrs.append( - model.addConstr(pre_var - pre_lb * (1 - a) >= var, - name=f'ReLU{self.name}_{neuron_idx}_a_0')) - new_relu_layer_constrs.append( - model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1')) - new_relu_layer_constrs.append( - model.addConstr(pre_ub * a >= var, name=f'ReLU{self.name}_{neuron_idx}_a_2')) - new_relu_layer_constrs.append( - model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_3')) - - elif model_type == "lp": - new_relu_layer_constrs.append( - model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_0')) - new_relu_layer_constrs.append( - model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1')) - new_relu_layer_constrs.append(model.addConstr( - pre_ub * pre_var - (pre_ub - pre_lb) * var >= pre_ub * pre_lb, - name=f'ReLU{self.name}_{neuron_idx}_a_2')) - - else: - print(f"gurobi model type {model_type} not supported!") - - new_layer_gurobi_vars.append(var) - - new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist() - if model_type in ["mip", "lp_integer"]: - self.integer_vars = relu_integer_vars - self.solver_vars = new_layer_gurobi_vars - self.solver_constrs = new_relu_layer_constrs - model.update() - - def dump_optimized_params(self): - return { - 'alpha': self.alpha, - 'alpha_lookup_idx': self.alpha_lookup_idx, - 'alpha_indices': self.alpha_indices - } - - def restore_optimized_params(self, opt_var_dict): - self.alpha, self.alpha_lookup_idx, self.alpha_indices = \ - opt_var_dict['alpha'], opt_var_dict['alpha_lookup_idx'], opt_var_dict['alpha_indices'] - - def build_gradient_node(self, grad_upstream): - node_grad = ReLUGrad() - grad_input = (grad_upstream, self.inputs[0].forward_value) - # An extra node is needed to consider the state of ReLU activation - grad_extra_nodes = [self.inputs[0]] - return node_grad, grad_input, grad_extra_nodes - - -class BoundLeakyRelu(BoundActivation): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.options = options.get('relu') - self.alpha = attr['alpha'] - - def forward(self, x): - return F.leaky_relu(x, negative_slope=self.alpha) - - def bound_backward(self, last_lA, last_uA, x=None, start_node=None, start_shape=None): - if x is not None: - lb_r = x.lower.clamp(max=0) - ub_r = x.upper.clamp(min=0) - else: - lb_r = self.lower.clamp(max=0) - ub_r = self.upper.clamp(min=0) - ub_r = torch.max(ub_r, lb_r + 1e-8) - upper_d = (ub_r - self.alpha * lb_r) / (ub_r - lb_r) - upper_b = - lb_r * upper_d + self.alpha * lb_r - - if self.options == "same-slope": - # the same slope for upper and lower - lower_d = upper_d - elif self.options == "zero-lb": - # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN - lower_d = (upper_d >= 1.0).to(upper_d.dtype) + (upper_d < 1.0).to(upper_d.dtype) * self.alpha - elif self.options == "one-lb": - # Always use slope 1 as lower bound - lower_d = (upper_d > 0.0).to(upper_d.dtype)+ (upper_d <= 0.0).to(upper_d.dtype) * self.alpha - else: - lower_d = (upper_d > 0.5).to(upper_d.dtype) + (upper_d <= 0.5).to(upper_d.dtype)* self.alpha - - upper_d = upper_d.unsqueeze(0) - lower_d = lower_d.unsqueeze(0) - # Choose upper or lower bounds based on the sign of last_A - uA = lA = None - ubias = lbias = 0 - if last_uA is not None: - neg_uA = last_uA.clamp(max=0) - pos_uA = last_uA.clamp(min=0) - uA = upper_d * pos_uA + lower_d * neg_uA - ubias = self.get_bias(pos_uA, upper_b) - if last_lA is not None: - neg_lA = last_lA.clamp(max=0) - pos_lA = last_lA.clamp(min=0) - lA = upper_d * neg_lA + lower_d * pos_lA - lbias = self.get_bias(neg_lA, upper_b) - return [(lA, uA)], lbias, ubias - - def dump_optimized_params(self): - return self.alpha - - def restore_optimized_params(self, alpha): - self.alpha = alpha - - -class BoundTanh(BoundOptimizableActivation): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.precompute_relaxation('tanh', torch.tanh, self.dtanh) - # Alpha dimension is (4, 2, output_shape, batch, *shape) for Tanh. - self.alpha_batch_dim = 3 - - def opt_init(self): - super().opt_init() - self.tp_both_lower_init = {} - self.tp_both_upper_init = {} - - def init_opt_parameters(self, start_nodes): - l, u = self.inputs[0].lower, self.inputs[0].upper - shape = l.shape - for ns, size_s, _ in start_nodes: - if isinstance(size_s, torch.Size): - size_s = prod(size_s) - self.alpha[ns] = torch.empty(4, 2, size_s, *shape, device=l.device) - self.alpha[ns].data[:2] = ((l + u) / 2).unsqueeze(0).expand(2, 2, size_s, *shape) - self.alpha[ns].data[2] = self.tp_both_lower_init[ns].expand(2, size_s, *shape) - self.alpha[ns].data[3] = self.tp_both_upper_init[ns].expand(2, size_s, *shape) - - def dtanh(self, x): - # to avoid bp error when cosh is too large - # cosh(25.0)**2 > 1e21 - mask = torch.lt(torch.abs(x), 25.0).to(x.dtype) - cosh = torch.cosh(mask * x + 1 - mask) - return mask * (1. / cosh.pow(2)) - - @torch.no_grad() - def precompute_relaxation(self, name, func, dfunc, x_limit = 500): - """ - This function precomputes the tangent lines that will be used as lower/upper bounds for S-shapes functions. - """ - self.x_limit = x_limit - self.step_pre = 0.01 - self.num_points_pre = int(self.x_limit / self.step_pre) - max_iter = 100 - - logger.debug('Precomputing relaxation for {}'.format(name)) - - def check_lower(upper, d): - """Given two points upper, d (d <= upper), check if the slope at d will be less than f(upper) at upper.""" - k = dfunc(d) - # Return True if the slope is a lower bound. - return k * (upper - d) + func(d) <= func(upper) - - def check_upper(lower, d): - """Given two points lower, d (d >= lower), check if the slope at d will be greater than f(lower) at lower.""" - k = dfunc(d) - # Return True if the slope is a upper bound. - return k * (lower - d) + func(d) >= func(lower) - - # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function. - upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) - r = torch.zeros_like(upper) - # Initial guess, the tangent line is at -1. - l = -torch.ones_like(upper) - while True: - # Check if the tangent line at the guessed point is an lower bound at f(upper). - checked = check_lower(upper, l).int() - # If the initial guess is not smaller enough, then double it (-2, -4, etc). - l = checked * l + (1 - checked) * (l * 2) - if checked.sum() == l.numel(): - break - # Now we have starting point at l, its tangent line is guaranteed to be an lower bound at f(upper). - # We want to further tighten this bound by moving it closer to 0. - for t in range(max_iter): - # Binary search. - m = (l + r) / 2 - checked = check_lower(upper, m).int() - l = checked * m + (1 - checked) * l - r = checked * r + (1 - checked) * m - # At upper, a line with slope l is guaranteed to lower bound the function. - self.d_lower = l.clone() - - # Do the same again: - # Given an lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function. - lower = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) - l = torch.zeros_like(upper) - r = torch.ones_like(upper) - while True: - checked = check_upper(lower, r).int() - r = checked * r + (1 - checked) * (r * 2) - if checked.sum() == l.numel(): - break - for t in range(max_iter): - m = (l + r) / 2 - checked = check_upper(lower, m).int() - l = (1 - checked) * m + checked * l - r = (1 - checked) * r + checked * m - self.d_upper = r.clone() - - logger.debug('Done') - - def forward(self, x): - return torch.tanh(x) - - def bound_relax_impl(self, x, func, dfunc): - # When self.x_limit is large enough, torch.tanh(self.x_limit)=1, - # and thus clipping is valid - lower = x.lower.clamp(min=-self.x_limit) - upper = x.upper.clamp(max=self.x_limit) - y_l, y_u = func(lower), func(upper) - - min_preact = 1e-6 - mask_close = (upper - lower) < min_preact - # k_direct is the slope of the line directly connect (lower, func(lower)), (upper, func(upper)). - k_direct = k = torch.where(mask_close, - dfunc(upper), (y_u - y_l) / (upper - lower).clamp(min=min_preact)) - - # Fixed bounds that cannot be optimized. self.mask_neg are the masks for neurons with upper bound <= 0. - # Upper bound for the case of input lower bound <= 0, is always the direct line. - self.add_linear_relaxation(mask=self.mask_neg, type='upper', k=k, x0=lower, y0=y_l) - # Lower bound for the case of input upper bound >= 0, is always the direct line. - self.add_linear_relaxation(mask=self.mask_pos, type='lower', k=k, x0=lower, y0=y_l) - - # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed. - # Note that for neurons with also input lower bound >=0, they will be masked later. - index = torch.max( - torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), - (upper / self.step_pre).to(torch.long).reshape(-1) - ) + 1 - # Lookup the lower bound slope from the pre-computed table. - d_lower = torch.index_select(self.d_lower, 0, index).view(lower.shape) - - # Indices of neurons with lower bound <=0, whose optimal slope to upper bound the function was pre-computed. - index = torch.max( - torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), - (lower / -self.step_pre).to(torch.long).reshape(-1) - ) + 1 - d_upper = torch.index_select(self.d_upper, 0, index).view(upper.shape) - - if self.opt_stage in ['opt', 'reuse']: - if not hasattr(self, 'alpha'): - # Raise an error if alpha is not created. - self._no_bound_parameters() - ns = self._start - - # Clipping is done here rather than after `opt.step()` call - # because it depends on pre-activation bounds - self.alpha[ns].data[0, :] = torch.max(torch.min(self.alpha[ns][0, :], upper), lower) - self.alpha[ns].data[1, :] = torch.max(torch.min(self.alpha[ns][1, :], upper), lower) - self.alpha[ns].data[2, :] = torch.min(self.alpha[ns][2, :], d_lower) - self.alpha[ns].data[3, :] = torch.max(self.alpha[ns][3, :], d_upper) - - # shape [2, out_c, n, c, h, w]. - tp_pos = self.alpha[ns][0] - tp_neg = self.alpha[ns][1] - tp_both_lower = self.alpha[ns][2] - tp_both_upper = self.alpha[ns][3] - - # No need to use tangent line, when the tangent point is at the left - # side of the preactivation lower bound. Simply connect the two sides. - mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower)) - self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) - self.add_linear_relaxation( - mask=torch.logical_xor(self.mask_both, mask_direct), type='lower', - k=dfunc(tp_both_lower), x0=tp_both_lower, - y0=self.forward(tp_both_lower)) - - mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper)) - self.add_linear_relaxation(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) - self.add_linear_relaxation( - mask=torch.logical_xor(self.mask_both, mask_direct), type='upper', - k=dfunc(tp_both_upper), x0=tp_both_upper, - y0=self.forward(tp_both_upper)) - - self.add_linear_relaxation( - mask=self.mask_neg, type='lower', - k=dfunc(tp_neg), x0=tp_neg, y0=self.forward(tp_neg)) - self.add_linear_relaxation( - mask=self.mask_pos, type='upper', - k=dfunc(tp_pos), x0=tp_pos, y0=self.forward(tp_pos)) - else: - # Not optimized (vanilla CROWN bound). - # Use the middle point slope as the lower/upper bound. Not optimized. - m = (lower + upper) / 2 - y_m = func(m) - k = dfunc(m) - # Lower bound is the middle point slope for the case input upper bound <= 0. - # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). - self.add_linear_relaxation(mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m) - # Upper bound is the middle point slope for the case input lower bound >= 0. - # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). - self.add_linear_relaxation(mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m) - - # Now handle the case where input lower bound <=0 and upper bound >= 0. - # A tangent line starting at d_lower is guaranteed to be a lower bound given the input upper bound. - k = dfunc(d_lower) - y0 = func(d_lower) - if self.opt_stage == 'init': - # Initialize optimizable slope. - ns = self._start - self.tp_both_lower_init[ns] = d_lower.detach() - # Another possibility is to use the direct line as the lower bound, when this direct line does not intersect with f. - # This is only valid when the slope at the input lower bound has a slope greater than the direct line. - mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower)) - self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) - # Otherwise we do not use the direct line, we use the d_lower slope. - self.add_linear_relaxation( - mask=torch.logical_xor(self.mask_both, mask_direct), - type='lower', k=k, x0=d_lower, y0=y0) - - # Do the same for the upper bound side when input lower bound <=0 and upper bound >= 0. - k = dfunc(d_upper) - y0 = func(d_upper) - if self.opt_stage == 'init': - ns = self._start - self.tp_both_upper_init[ns] = d_upper.detach() - self.tmp_lower = x.lower.detach() - self.tmp_upper = x.upper.detach() - mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper)) - self.add_linear_relaxation(mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) - self.add_linear_relaxation( - mask=torch.logical_xor(self.mask_both, mask_direct), - type='upper', k=k, x0=d_upper, y0=y0) - - def bound_relax(self, x): - self.bound_relax_impl(x, torch.tanh, self.dtanh) - - def dump_optimized_params(self): - return self.alpha - - def restore_optimized_params(self, alpha): - self.alpha = alpha - - -class BoundSigmoid(BoundTanh): - def __init__(self, attr, inputs, output_index, options): - super(BoundTanh, self).__init__(attr, inputs, output_index, options) - self.precompute_relaxation('sigmoid', torch.sigmoid, self.dsigmoid) - # Alpha dimension is (4, 2, output_shape, batch, *shape) for S-shaped functions. - self.alpha_batch_dim = 3 - - def forward(self, x): - return torch.sigmoid(x) - - def dsigmoid(self, x): - return torch.sigmoid(x) * (1 - torch.sigmoid(x)) - - def bound_relax(self, x): - self.bound_relax_impl(x, torch.sigmoid, self.dsigmoid) - - class BoundSoftplus(BoundActivation): - def __init__(self, attr, inputs, output_index, options): - super(BoundSoftplus, self).__init__(attr, inputs, output_index, options) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) self.softplus = nn.Softplus() def forward(self, x): @@ -959,13 +18,10 @@ def forward(self, x): class BoundAbs(BoundActivation): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - def forward(self, x): return x.abs() - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): x_L = x.lower.clamp(max=0) x_U = torch.max(x.upper.clamp(min=0), x_L + 1e-8) mask_neg = x_U <= 0 @@ -1011,25 +67,22 @@ def interval_propagate(self, *v): class BoundATenHeaviside(BoundOptimizableActivation): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.alpha_batch_dim = 2 - def forward(self, *x): self.input_shape = x[0].shape # x[0]: input; x[1]: value when x == 0 return torch.heaviside(x[0], x[1]) - def init_opt_parameters(self, start_nodes): + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" l = self.inputs[0].lower - for ns, size_s, _ in start_nodes: - self.alpha[ns] = torch.zeros_like(l).unsqueeze(0).repeat(2, *[1] * l.ndim).requires_grad_(True) + return torch.zeros_like(l).unsqueeze(0).repeat(2, *[1] * l.ndim) - def clip_alpha_(self): + def clip_alpha(self): for v in self.alpha.values(): v.data = torch.clamp(v.data, 0., 1.) - def bound_backward(self, last_lA, last_uA, *x, start_node=None, start_shape=None): + def bound_backward(self, last_lA, last_uA, *x, start_node=None, + start_shape=None, **kwargs): x = x[0] if x is not None: lb_r = x.lower diff --git a/auto_LiRPA/operators/add_sub.py b/auto_LiRPA/operators/add_sub.py new file mode 100644 index 0000000..ebdd586 --- /dev/null +++ b/auto_LiRPA/operators/add_sub.py @@ -0,0 +1,158 @@ +from .base import * +from .solver_utils import grb + + +class BoundAdd(Bound): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + options = options or {} + # FIXME: This is not the right way to enable patches mode. + # Instead we must traverse the graph and determine when patches mode needs to be used. + + self.mode = options.get("conv_mode", "matrix") + + def forward(self, x, y): + self.x_shape = x.shape + self.y_shape = y.shape + return x + y + + def bound_backward(self, last_lA, last_uA, x, y, **kwargs): + def _bound_oneside(last_A, w): + if last_A is None: + return None + return self.broadcast_backward(last_A, w) + + uA_x = _bound_oneside(last_uA, x) + uA_y = _bound_oneside(last_uA, y) + lA_x = _bound_oneside(last_lA, x) + lA_y = _bound_oneside(last_lA, y) + return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0 + + def bound_forward(self, dim_in, x, y): + lb, ub = x.lb + y.lb, x.ub + y.ub + + def add_w(x_w, y_w, x_b, y_b): + if x_w is None and y_w is None: + return None + elif x_w is not None and y_w is not None: + return x_w + y_w + elif y_w is None: + return x_w + torch.zeros_like(y_b) + else: + return y_w + torch.zeros_like(x_b) + + lw = add_w(x.lw, y.lw, x.lb, y.lb) + uw = add_w(x.uw, y.uw, x.ub, y.ub) + + return LinearBound(lw, lb, uw, ub) + + def interval_propagate(self, x, y): + assert (not isinstance(y, Tensor)) + return x[0] + y[0], x[1] + y[1] + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if isinstance(v[0], Tensor) and isinstance(v[1], Tensor): + # constants if both inputs are tensors + self.solver_vars = self.forward(v[0], v[1]) + return + # we have both gurobi vars as inputs + this_layer_shape = self.output_shape + gvar_array1 = np.array(v[0]) + gvar_array2 = np.array(v[1]) + assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:] + + # flatten to create vars and constrs first + gvar_array1 = gvar_array1.reshape(-1) + gvar_array2 = gvar_array2.reshape(-1) + new_layer_gurobi_vars = [] + for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)): + var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(var == (var1 + var2), name=f'lay{self.name}_{neuron_idx}_eq') + new_layer_gurobi_vars.append(var) + + # reshape to the correct list shape of solver vars + self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist() + model.update() + + +class BoundSub(Bound): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used. + self.mode = options.get("conv_mode", "matrix") + + def forward(self, x, y): + self.x_shape = x.shape + self.y_shape = y.shape + return x - y + + def bound_backward(self, last_lA, last_uA, x, y, **kwargs): + def _bound_oneside(last_A, w, sign=-1): + if last_A is None: + return None + if isinstance(last_A, torch.Tensor): + return self.broadcast_backward(sign * last_A, w) + elif isinstance(last_A, Patches): + if sign == 1: + # Patches shape requires no broadcast. + return last_A + else: + # Multiply by the sign. + return last_A.create_similar(sign * last_A.patches) + else: + raise ValueError(f'Unknown last_A type {type(last_A)}') + + uA_x = _bound_oneside(last_uA, x, sign=1) + uA_y = _bound_oneside(last_uA, y, sign=-1) + lA_x = _bound_oneside(last_lA, x, sign=1) + lA_y = _bound_oneside(last_lA, y, sign=-1) + return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0 + + def bound_forward(self, dim_in, x, y): + lb, ub = x.lb - y.ub, x.ub - y.lb + + def add_w(x_w, y_w, x_b, y_b): + if x_w is None and y_w is None: + return None + elif x_w is not None and y_w is not None: + return x_w + y_w + elif y_w is None: + return x_w + torch.zeros_like(y_b) + else: + return y_w + torch.zeros_like(x_b) + + lw = add_w(x.lw, -y.uw, x.lb, y.lb) + uw = add_w(x.uw, -y.lw, x.ub, y.ub) + + return LinearBound(lw, lb, uw, ub) + + def interval_propagate(self, x, y): + return x[0] - y[1], x[1] - y[0] + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if isinstance(v[0], Tensor) and isinstance(v[1], Tensor): + # constants if both inputs are tensors + self.solver_vars = self.forward(v[0], v[1]) + return + # we have both gurobi vars as inputs + this_layer_shape = self.output_shape + gvar_array1 = np.array(v[0]) + gvar_array2 = np.array(v[1]) + assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:] + + # flatten to create vars and constrs first + gvar_array1 = gvar_array1.reshape(-1) + gvar_array2 = gvar_array2.reshape(-1) + new_layer_gurobi_vars = [] + for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)): + var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'lay{self.name}_{neuron_idx}') + model.addConstr(var == (var1 - var2), name=f'lay{self.name}_{neuron_idx}_eq') + new_layer_gurobi_vars.append(var) + + # reshape to the correct list shape of solver vars + self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist() + model.update() diff --git a/auto_LiRPA/operators/base.py b/auto_LiRPA/operators/base.py index 69b548e..45ad854 100644 --- a/auto_LiRPA/operators/base.py +++ b/auto_LiRPA/operators/base.py @@ -1,4 +1,5 @@ """ Base class and functions for implementing bound operators""" +from typing import Optional, List import warnings import torch import torch.nn as nn @@ -66,14 +67,14 @@ def get_perturbation(interval): if isinstance(interval.ptb, PerturbationLpNorm): return interval.ptb.norm, interval.ptb.eps elif isinstance(interval.ptb, PerturbationSynonym): - return np.inf, 1.0 + return torch.inf, 1.0 elif isinstance(interval.ptb, PerturbationL0Norm): return 0, interval.ptb.eps, interval.ptb.ratio else: raise RuntimeError("get_perturbation() does not know how to handle {}".format(type(interval.ptb))) else: # Tuple object. Assuming L infinity norm lower and upper bounds. - return np.inf, np.nan + return torch.inf, np.nan @staticmethod @@ -108,7 +109,7 @@ def __init__(self, attr=None, inputs=None, output_index=0, options=None): attr = {} if attr is None else attr inputs = [] if inputs is None else inputs options = {} if options is None else options - self.name = None + self.name: Optional[str] = None self.output_name = [] self.device = attr.get('device') self.attr, self.inputs, self.output_index, self.options = \ @@ -118,8 +119,16 @@ def __init__(self, attr=None, inputs=None, output_index=0, options=None): self.from_input = False self.bounded = False self.IBP_rets = None + self.requires_input_bounds = [] + # If True, when we are computing intermediate bounds for these ops, + # we simply use IBP to propagate bounds from its input nodes + # instead of CROWN. Currently only operators with a single input can be + # supported. + self.ibp_intermediate = False + self.splittable = False # Determine if this node has a perturbed output or not. The function BoundedModule._mark_perturbed_nodes() will set this property. self.perturbed = False + self.never_perturbed = False if options is not None and 'loss_fusion' in options: self.loss_fusion = options['loss_fusion'] else: @@ -133,11 +142,55 @@ def __init__(self, attr=None, inputs=None, output_index=0, options=None): # If set to true, the A matrix accumulated on this node is 0. self.zero_lA_mtx = False self.zero_uA_mtx = False - self.patches_start = False + self.alpha_beta_update_mask = None + + def __repr__(self, attrs=None): + inputs = ', '.join([node.name for node in self.inputs]) + ret = (f'{self.__class__.__name__}(name={self.name}, ' + f'inputs=[{inputs}], perturbed={self.perturbed}') + if attrs is not None: + for k, v in attrs.items(): + ret += f', {k}={v}' + ret += ')' + return ret - def __repr__(self): - return f'{self.__class__.__name__}(name="{self.name}")' + def are_output_constraints_activated_for_layer( + self: 'Bound', + apply_output_constraints_to: Optional[List[str]], + ): + if apply_output_constraints_to is None: + return False + for layer_type_or_name in apply_output_constraints_to: + if layer_type_or_name.startswith('/'): + if self.name == layer_type_or_name: + return True + else: + assert layer_type_or_name.startswith('Bound'), ( + 'To apply output constraints to tighten layer bounds, pass either the layer name ' + '(starting with "/", e.g. "/input.7") or the layer type (starting with "Bound", ' + 'e.g. "BoundLinear")' + ) + if type(self).__name__ == layer_type_or_name: + return True + return False + + def init_gammas(self, num_constraints): + if not self.are_output_constraints_activated_for_layer( + self.options.get('optimize_bound_args', {}).get('apply_output_constraints_to', []) + ): + return + assert len(self.output_shape) > 0, self + neurons_in_this_layer = 1 + for d in self.output_shape[1:]: + neurons_in_this_layer *= d + init_gamma_value = 0.0 + self.gammas = torch.full((2, num_constraints), init_gamma_value, requires_grad=True, device=self.device) + + def clip_gammas(self): + if not hasattr(self, "gammas"): + return + self.gammas.data = torch.clamp(self.gammas.data, min=0.0) def is_input_perturbed(self, i=0): r"""Check if the i-th input is with perturbation or not.""" @@ -179,20 +232,26 @@ def interval_propagate(self, *v): Returns: bound: The interval bound of this node, in a same format as v[i]. """ - if self.use_default_ibp: + if self.use_default_ibp or self.never_perturbed: return self.default_interval_propagate(*v) else: return not_implemented_op(self, 'interval_propagate') def default_interval_propagate(self, *v): - """For unary monotonous functions or functions for altering shapes only but not values""" + """Default IBP using the forward function. + + For unary monotonous functions or functions for altering shapes only + but not values. + """ if len(v) == 0: return Interval.make_interval(self.forward(), self.forward()) - elif len(v) == 1: - return Interval.make_interval( - self.forward(v[0][0]), self.forward(v[0][1]), v[0]) else: - raise NotImplementedError('default_interval_propagate only supports no more than 1 input node') + if len(v) > 1: + for i in range(1, len(v)): + assert not self.is_input_perturbed(i) + return Interval.make_interval( + self.forward(v[0][0], *[vv[0] for vv in v[1:]]), + self.forward(v[0][1], *[vv[0] for vv in v[1:]]), v[0]) def bound_forward(self, dim_in, *x): r""" @@ -225,7 +284,7 @@ def bound_forward(self, dim_in, *x): def bound_dynamic_forward(self, *x, max_dim=None, offset=0): raise NotImplementedError(f'bound_dynamic_forward is not implemented for {self}.') - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): r""" Function for backward mode bound propagation. @@ -247,33 +306,33 @@ def bound_backward(self, last_lA, last_uA, *x): def broadcast_backward(self, A, x): shape = x.output_shape - batch_dim = max(self.batch_dim, 0) if isinstance(A, Tensor): if x.batch_dim == -1: # final shape of input - shape = torch.Size([A.shape[batch_dim + 1]] + list(shape)) + shape = torch.Size([A.shape[1]] + list(shape)) dims = [] cnt_sum = A.ndim - len(shape) - 1 - for i in range(1, A.ndim): # merge the output dimensions? - if i != self.batch_dim + 1 and cnt_sum > 0: + for i in range(2, A.ndim): # merge the output dimensions? + if cnt_sum > 0: dims.append(i) cnt_sum -= 1 if dims: A = torch.sum(A, dim=dims) else: - dims = list(range(1, 1 + A.ndim - 1 - len(shape))) + dims = list(range(1, A.ndim - len(shape))) if dims: A = torch.sum(A, dim=dims) dims = [] - for i in range(len(shape)): + for i in range(1, len(shape)): # Skip the batch dimension. - # FIXME (05/11/2022): the following condition is not always correct. We should not rely on checking dimension is "1" or not. - if shape[i] == 1 and A.shape[i + 1] != 1 and i != batch_dim: + # FIXME (05/11/2022): the following condition is not always correct. + # We should not rely on checking dimension is "1" or not. + if shape[i] == 1 and A.shape[i + 1] != 1: dims.append(i + 1) if dims: A = torch.sum(A, dim=dims, keepdim=True) - assert (A.shape[2:] == shape[1:]) # skip the spec and batch dimension. + assert A.shape[2:] == shape[1:] # skip the spec and batch dimension. else: pass return A @@ -286,14 +345,14 @@ def build_gradient_node(self, grad_upstream): grad_upstream: Upstream gradient in the gradient back-propagation. Returns: - node_grad (Bound): Gradient node. + module_grad (torch.nn.Module): Gradient node. grad_input (list): Inputs to the gradient node. Values do not matter. We only want the shapes. grad_extra_nodes (list): Extra nodes needed for the gradient. """ - return not_implemented_op(self, 'bound_forward') + return not_implemented_op(self, 'build_gradient_node') def get_bias(self, A, bias): if A is None: @@ -352,6 +411,9 @@ def get_bias(self, A, bias): return NotImplementedError() def make_axis_non_negative(self, axis, shape='input'): + if isinstance(axis, (tuple, list)): + return tuple([self.make_axis_non_negative(item, shape) + for item in axis]) if shape == 'input': shape = self.input_shape elif shape == 'output': @@ -363,8 +425,83 @@ def make_axis_non_negative(self, axis, shape='input'): else: return axis + def update_requires_input_bounds(self): + """Update requires_input_bounds. + + This function is called once we know if the input nodesare perturbed. + """ + pass + + def clamp_interim_bounds(self): + """Clamp intermediate bounds.""" + pass + + def check_constraint_available(self, node, flag=False): + if hasattr(node, 'cstr_interval'): + flag = True + for n in node.inputs: + if not n.from_input: + flag = flag or self.check_constraint_available(n, flag) + return flag + + def _ibp_constraint(self, node, delete_bounds_after_use=False): + def _delete_unused_bounds(node_list): + """Delete bounds from input layers after use to save memory. Used when + sparse_intermediate_bounds_with_ibp is true.""" + if delete_bounds_after_use: + for n in node_list: + del n.cstr_interval + del n.cstr_lower + del n.cstr_upper + + if not node.perturbed and hasattr(node, 'forward_value'): + node.cstr_lower, node.cstr_upper = node.cstr_interval = ( + node.forward_value, node.forward_value) + + to_be_deleted_bounds = [] + if not hasattr(node, 'cstr_interval'): + for n in node.inputs: + if not hasattr(n, 'cstr_interval'): + # Node n does not have interval bounds; we must compute it. + self._ibp_constraint( + n, delete_bounds_after_use=delete_bounds_after_use) + to_be_deleted_bounds.append(n) + inp = [n_pre.cstr_interval for n_pre in node.inputs] + node.cstr_interval = node.interval_propagate(*inp) + + node.cstr_lower, node.cstr_upper = node.cstr_interval + if isinstance(node.cstr_lower, torch.Size): + node.cstr_lower = torch.tensor(node.cstr_lower) + node.cstr_interval = (node.cstr_lower, node.cstr_upper) + if isinstance(node.cstr_upper, torch.Size): + node.cstr_upper = torch.tensor(node.cstr_upper) + node.cstr_interval = (node.cstr_lower, node.cstr_upper) + + if hasattr(node, 'lower'): + node.lower = torch.where(node.lower >= node.cstr_lower, node.lower, + node.cstr_lower) + node.upper = torch.where(node.upper <= node.cstr_upper, node.upper, + node.cstr_upper) + node.interval = (node.lower, node.upper) + + _delete_unused_bounds(to_be_deleted_bounds) + return node.cstr_interval + + def _check_weight_perturbation(self): + weight_perturbation = False + for n in self.inputs[1:]: + if hasattr(n, 'perturbation'): + if n.perturbation is not None: + weight_perturbation = True + if weight_perturbation: + self.requires_input_bounds = list(range(len(self.inputs))) + else: + self.requires_input_bounds = [] + return weight_perturbation + def non_deter_wrapper(self, op, *args, **kwargs): - """Some operations are non-deterministic and deterministic mode will fail. So we temporary disable it.""" + """Some operations are non-deterministic and deterministic mode will fail. + So we temporary disable it.""" if self.options.get('deterministic', False): torch.use_deterministic_algorithms(False) ret = op(*args, **kwargs) diff --git a/auto_LiRPA/operators/bivariate.py b/auto_LiRPA/operators/bivariate.py index baabe31..5b79ada 100644 --- a/auto_LiRPA/operators/bivariate.py +++ b/auto_LiRPA/operators/bivariate.py @@ -1,89 +1,115 @@ """ Bivariate operators""" -import copy +import torch +from torch import Tensor +from typing import Dict, Optional from .base import * -from .nonlinear import BoundSqrt, BoundReciprocal +from .activation_base import BoundOptimizableActivation +from .nonlinear import BoundSqrt from .clampmult import multiply_by_A_signs from ..utils import * from .solver_utils import grb -from .constant import BoundConstant -from .leaf import BoundParams, BoundBuffers -class BoundMul(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.is_constant_op = False - for inp in inputs: - if BoundMul._check_const_input(inp): - # If any of the two inputs are constant, we do not need input bounds. - # FIXME (05/11/2022): this is just a temporary workaround. We need better way to determine whether we need input bounds, not just for BoundConstant. - self.is_constant_op = True - if self.is_constant_op: - # One input is constant; no bounds required. - self.requires_input_bounds = [] - else: - # Both inputs are perturbed. Need relaxation. - self.requires_input_bounds = [0, 1] +class MulHelper: + """Handle linear relaxation for multiplication. - @staticmethod - def _check_const_input(inp): - return isinstance(inp, (BoundConstant, BoundBuffers)) or (isinstance(inp, BoundParams) and inp.perturbation is None) + This helper can be used by BoundMul, BoundMatMul, + BoundLinear (with weight perturbation). + """ - def forward(self, x, y): - self.x_shape = x.shape - self.y_shape = y.shape - return x * y + def __init__(self): + pass @staticmethod - def get_bound_mul(x_l, x_u, y_l, y_u): - alpha_l = y_l - beta_l = x_l - gamma_l = -alpha_l * beta_l - - alpha_u = y_u - beta_u = x_l - gamma_u = -alpha_u * beta_u - return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u + def interpolated_relaxation(x_l: Tensor, x_u: Tensor, + y_l: Tensor, y_u: Tensor, + r_l: Optional[Tensor] = None, + r_u: Optional[Tensor] = None + ) -> Tuple[Tensor, Tensor, Tensor, + Tensor, Tensor, Tensor]: + """Interpolate two optimal linear relaxations for optimizable bounds.""" + if r_l is None and r_u is None: + alpha_l, beta_l, gamma_l = y_l, x_l, -y_l * x_l + alpha_u, beta_u, gamma_u = y_u, x_l, -y_u * x_l + return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u + else: + assert isinstance(r_l, Tensor) and isinstance(r_u, Tensor) + # TODO (for zhouxing/qirui): this function may benefit from JIT, + # because it has many element-wise operation which can be fused. + # Need to benchmark and see performance. + alpha_l = (y_l - y_u) * r_l + y_u + beta_l = (x_l - x_u) * r_l + x_u + gamma_l = (y_u * x_u - y_l * x_l) * r_l - y_u * x_u + alpha_u = (y_u - y_l) * r_u + y_l + beta_u = (x_l - x_u) * r_u + x_u + gamma_u = (y_l * x_u - y_u * x_l) * r_u - y_l * x_u + return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u - # Special case when input is x * x. @staticmethod - def get_bound_square(x_l, x_u): - # Lower bound is a z=0 line if x_l and x_u have different signs. - # Otherwise, the lower bound is a tangent line at x_l. - # The lower bound should always be better than IBP. - - # If both x_l and x_u < 0, select x_u. If both > 0, select x_l. - # If x_l < 0 and x_u > 0, we use the z=0 line as the lower bound. - x_m = F.relu(x_l) - F.relu(-x_u) - alpha_l = 2 * x_m - gamma_l = - x_m * x_m - - # Upper bound: connect the two points (x_l, x_l^2) and (x_u, x_u^2). - # The upper bound should always be better than IBP. - alpha_u = x_l + x_u - gamma_u = - x_l * x_u - - # Parameters before the second variable are all zeros, not used. - beta_l = torch.zeros_like(x_l) - beta_u = beta_l - return alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u + def get_relaxation(x_l: Tensor, x_u: Tensor, y_l: Tensor, y_u: Tensor, + opt_stage: Optional[str], + alphas: Optional[Dict[str, Tensor]], + start_name: Optional[str], + ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + if opt_stage in ['opt', 'reuse']: + assert x_l.ndim == y_l.ndim + ns = start_name + alphas[ns].data[:] = alphas[ns].data[:].clamp(min=0, max=1) + return MulHelper.interpolated_relaxation( + x_l, x_u, y_l, y_u, alphas[ns][:2], alphas[ns][2:4]) + else: + return MulHelper.interpolated_relaxation(x_l, x_u, y_l, y_u) @staticmethod - def _relax(x, y): - if x is y: - # A shortcut for x * x. - return BoundMul.get_bound_square(x.lower, x.upper) - - x_l, x_u = x.lower, x.upper - y_l, y_u = y.lower, y.upper - + def get_forward_relaxation(x_l, x_u, y_l, y_u, opt_stage, alpha, start_name): # Broadcast + # FIXME perhaps use a more efficient way x_l = x_l + torch.zeros_like(y_l) x_u = x_u + torch.zeros_like(y_u) y_l = y_l + torch.zeros_like(x_l) y_u = y_u + torch.zeros_like(x_u) + return MulHelper.get_relaxation(x_l, x_u, y_l, y_u, opt_stage, alpha, start_name) + + @staticmethod + def _get_gap(x, y, alpha, beta): + return x * y - alpha * x - beta * y + + +class BoundMul(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.splittable = True + self.mul_helper = MulHelper() + + def forward(self, x, y): + self.x_shape = x.shape + self.y_shape = y.shape + return x * y - return BoundMul.get_bound_mul(x_l, x_u, y_l, y_u) + def get_relaxation_opt(self, x_l, x_u, y_l, y_u): + return self.mul_helper.get_relaxation( + x_l, x_u, y_l, y_u, self.opt_stage, getattr(self, 'alpha', None), + self._start) + + def _init_opt_parameters_impl(self, size_spec, **kwargs): + """Implementation of init_opt_parameters for each start_node.""" + x_l = self.inputs[0].lower + y_l = self.inputs[1].lower + assert x_l.ndim == y_l.ndim + shape = [max(x_l.shape[i], y_l.shape[i]) for i in range(x_l.ndim)] + alpha = torch.ones(4, size_spec, *shape, device=x_l.device) + return alpha + + def bound_relax(self, x, y, init=False, dim_opt=None): + if init: + pass + (alpha_l, beta_l, gamma_l, + alpha_u, beta_u, gamma_u) = self.get_relaxation_opt( + x.lower, x.upper, y.lower, y.upper) + self.lw = [alpha_l, beta_l] + self.lb = gamma_l + self.uw = [alpha_u, beta_u] + self.ub = gamma_u @staticmethod def _multiply_by_const(x, const): @@ -108,133 +134,129 @@ def _multiply_by_const(x, const): else: raise ValueError(f'Unsupported x type {type(x)}') - @staticmethod - def bound_backward_constant(last_lA, last_uA, x, y, op=None): + def bound_backward_constant(self, last_lA, last_uA, x, y, op=None, + reduce_bias=True, **kwargs): + assert reduce_bias op = BoundMul._multiply_by_const if op is None else op # Handle the case of multiplication by a constant. factor = None - if not BoundMul._check_const_input(x): - factor = y.value - if not BoundMul._check_const_input(y): - factor = x.value + if x.perturbed: + factor = y.forward_value + if y.perturbed: + factor = x.forward_value # No need to compute A matrix if it is Constant. - lAx = None if BoundMul._check_const_input(x) or last_lA is None else op(last_lA, factor) - lAy = None if BoundMul._check_const_input(y) or last_lA is None else op(last_lA, factor) - uAx = None if BoundMul._check_const_input(x) or last_uA is None else op(last_uA, factor) - uAy = None if BoundMul._check_const_input(y) or last_uA is None else op(last_uA, factor) - + lAx = (None if not x.perturbed or last_lA is None + else self.broadcast_backward(op(last_lA, factor), x)) + uAx = (None if not x.perturbed or last_uA is None + else self.broadcast_backward(op(last_uA, factor), x)) + lAy = (None if not y.perturbed or last_lA is None + else self.broadcast_backward(op(last_lA, factor), y)) + uAy = (None if not y.perturbed or last_uA is None + else self.broadcast_backward(op(last_uA, factor), y)) return [(lAx, uAx), (lAy, uAy)], 0., 0. - - def bound_backward(self, last_lA, last_uA, x, y): - if self.is_constant_op: - return self.bound_backward_constant(last_lA, last_uA, x, y) + def bound_backward(self, last_lA, last_uA, x, y, start_node=None, **kwargs): + if start_node is not None: + self._start = start_node.name + if self.is_linear_op: + ret = self.bound_backward_constant(last_lA, last_uA, x, y, **kwargs) else: - return self.bound_backward_both_perturbed(last_lA, last_uA, x, y) - - def bound_backward_both_perturbed(self, last_lA, last_uA, x, y): - alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = BoundMul._relax(x, y) + ret = self.bound_backward_both_perturbed( + last_lA, last_uA, x, y, **kwargs) + return ret - alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0) - beta_l, beta_u = beta_l.unsqueeze(0), beta_u.unsqueeze(0) + def bound_backward_both_perturbed(self, last_lA, last_uA, x, y, + reduce_bias=True, **kwargs): + self.bound_relax(x, y) - def _bound_oneside(last_A, - alpha_pos, beta_pos, gamma_pos, - alpha_neg, beta_neg, gamma_neg): + def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, + alpha_neg, beta_neg, gamma_neg, opt=False): if last_A is None: return None, None, 0 if type(last_A) == Patches: - # In patches mode, we need to unfold lower and upper slopes. In matrix mode we simply return. - def _maybe_unfold(d_tensor, last_A): - if d_tensor is None: - return None - - d_shape = d_tensor.size() - # Reshape to 4-D tensor to unfold. - d_tensor = d_tensor.view(-1, *d_shape[-3:]) - # unfold the slope matrix as patches. Patch shape is [spec * batch, out_h, out_w, in_c, H, W). - d_unfolded = inplace_unfold(d_tensor, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) - # Reshape to (spec, batch, out_h, out_w, in_c, H, W); here spec_size is out_c. - d_unfolded_r = d_unfolded.view(*last_A.shape[:3], *d_unfolded.shape[1:]) - if last_A.unstable_idx is not None: - if d_unfolded_r.size(0) == 1: - # Broadcast the spec shape, so only need to select the reset dimensions. - # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W). - d_unfolded_r = d_unfolded_r.squeeze(0).permute(1, 2, 0, 3, 4, 5) - d_unfolded_r = d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] - # output shape: (unstable_size, batch, in_c, H, W). - else: - d_unfolded_r = d_unfolded_r[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]] - # For sparse patches, the shape after unfold is (unstable_size, batch_size, in_c, H, W). - # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W). - return d_unfolded_r - # if last_A is not an identity matrix + assert reduce_bias assert last_A.identity == 0 - if last_A.identity == 0: - # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. Here out_c is the spec dimension. - # for patches mode, we need to unfold the alpha_pos/neg and beta_pos/neg - - alpha_pos = _maybe_unfold(alpha_pos, last_A) - alpha_neg = _maybe_unfold(alpha_neg, last_A) - beta_pos = _maybe_unfold(beta_pos, last_A) - beta_neg = _maybe_unfold(beta_neg, last_A) - - gamma_pos = _maybe_unfold(gamma_pos, last_A) - gamma_neg = _maybe_unfold(gamma_neg, last_A) - - patches = last_A.patches - patches_shape = patches.shape - A_x, bias = multiply_by_A_signs(patches.view(*patches_shape[:5], -1, *patches_shape[-2:]), alpha_pos, alpha_neg, gamma_pos, gamma_neg, patches_mode=True) - A_y, _ = multiply_by_A_signs(patches.view(*patches_shape[:5], -1, *patches_shape[-2:]), beta_pos, beta_neg, None, None, patches_mode=True) - A_x = A_x.view(patches_shape) - A_y = A_y.view(patches_shape) - - # broadcast_backward - x_dims = [] - y_dims = [] - - if A_x.shape[A_x.ndim-4] != x.output_shape[len(x.output_shape)-4]: - x_dims.append(A_x.ndim-4) - - if A_y.shape[A_y.ndim-4] != y.output_shape[len(y.output_shape)-4]: - y_dims.append(A_y.ndim-4) - - if len(x_dims) > 0: - A_x = A_x.sum(tuple(x_dims), keepdim=True) - if len(y_dims) > 0: - A_y = A_y.sum(tuple(y_dims), keepdim=True) - - A_x = Patches(A_x, last_A.stride, last_A.padding, A_x.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) - A_y = Patches(A_y, last_A.stride, last_A.padding, A_y.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) - if type(last_A) == Tensor: + # last_A shape: [out_c, batch_size, out_h, out_w, in_c, H, W]. + # Here out_c is the spec dimension. + # for patches mode, we need to unfold the alpha_pos/neg and beta_pos/neg + alpha_pos = maybe_unfold_patches(alpha_pos, last_A) + alpha_neg = maybe_unfold_patches(alpha_neg, last_A) + beta_pos = maybe_unfold_patches(beta_pos, last_A) + beta_neg = maybe_unfold_patches(beta_neg, last_A) + gamma_pos = maybe_unfold_patches(gamma_pos, last_A) + gamma_neg = maybe_unfold_patches(gamma_neg, last_A) + A_x, bias = multiply_by_A_signs( + last_A, alpha_pos, alpha_neg, gamma_pos, gamma_neg) + A_y, _ = multiply_by_A_signs( + last_A, beta_pos, beta_neg, None, None) + elif type(last_A) == Tensor: last_A_pos, last_A_neg = last_A.clamp(min=0), last_A.clamp(max=0) - A_x = last_A_pos * alpha_pos + last_A_neg * alpha_neg - A_y = last_A_pos * beta_pos + last_A_neg * beta_neg + A_x, _ = multiply_by_A_signs(last_A, alpha_pos, alpha_neg, None, None) + A_y, _ = multiply_by_A_signs(last_A, beta_pos, beta_neg, None, None) A_x = self.broadcast_backward(A_x, x) A_y = self.broadcast_backward(A_y, y) - bias = self.get_bias(last_A_pos, gamma_pos) + \ - self.get_bias(last_A_neg, gamma_neg) + if reduce_bias: + if opt: + bias = (torch.einsum('sb...,sb...->sb', + last_A_pos, gamma_pos) + + torch.einsum('sb...,sb...->sb', + last_A_neg, gamma_neg)) + else: + bias = (self.get_bias(last_A_pos, gamma_pos.squeeze(0)) + + self.get_bias(last_A_neg, gamma_neg.squeeze(0))) + else: + assert not opt + bias = last_A_pos * gamma_pos + last_A_neg * gamma_neg + assert len(x.output_shape) == bias.ndim - 1 + assert len(y.output_shape) == bias.ndim - 1 + bias_x = bias_y = bias + for i in range(2, bias.ndim): + if bias_x.shape[i] != x.output_shape[i - 1]: + assert x.output_shape[i - 1] == 1 + bias_x = bias_x.sum(i, keepdim=True) + for i in range(2, bias.ndim): + if bias_y.shape[i] != y.output_shape[i - 1]: + assert y.output_shape[i - 1] == 1 + bias_y = bias_y.sum(i, keepdim=True) + bias = (bias_x, bias_y) + else: + raise NotImplementedError(last_A) return A_x, A_y, bias - lA_x, lA_y, lbias = _bound_oneside( - last_lA, alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u) - uA_x, uA_y, ubias = _bound_oneside( - last_uA, alpha_u, beta_u, gamma_u, alpha_l, beta_l, gamma_l) + alpha_l, beta_l, gamma_l = self.lw[0], self.lw[1], self.lb + alpha_u, beta_u, gamma_u = self.uw[0], self.uw[1], self.ub + + if self.opt_stage in ['opt', 'reuse']: + lA_x, lA_y, lbias = _bound_oneside( + last_lA, alpha_l[0], beta_l[0], gamma_l[0], + alpha_u[0], beta_u[0], gamma_u[0], opt=True) + uA_x, uA_y, ubias = _bound_oneside( + last_uA, alpha_u[1], beta_u[1], gamma_u[1], + alpha_l[1], beta_l[1], gamma_l[1], opt=True) + else: + alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0) + beta_l, beta_u = beta_l.unsqueeze(0), beta_u.unsqueeze(0) + gamma_l, gamma_u = gamma_l.unsqueeze(0), gamma_u.unsqueeze(0) + lA_x, lA_y, lbias = _bound_oneside( + last_lA, alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u) + uA_x, uA_y, ubias = _bound_oneside( + last_uA, alpha_u, beta_u, gamma_u, alpha_l, beta_l, gamma_l) return [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias def bound_forward(self, dim_in, x, y): - if self.is_constant_op: + if self.is_linear_op: raise NotImplementedError return self.bound_forward_both_perturbed(dim_in, x, y) - @staticmethod - def bound_forward_both_perturbed(dim_in, x, y): + def bound_forward_both_perturbed(self, dim_in, x, y): x_lw, x_lb, x_uw, x_ub = x.lw, x.lb, x.uw, x.ub y_lw, y_lb, y_uw, y_ub = y.lw, y.lb, y.uw, y.ub - alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = BoundMul._relax(x, y) + (alpha_l, beta_l, gamma_l, + alpha_u, beta_u, gamma_u) = MulHelper.get_forward_relaxation( + x.lower, x.upper, y.lower, y.upper, self.opt_stage, getattr(self, 'alpha', None), self._start) if x_lw is None: x_lw = 0 if y_lw is None: y_lw = 0 @@ -243,35 +265,37 @@ def bound_forward_both_perturbed(dim_in, x, y): lw = alpha_l.unsqueeze(1).clamp(min=0) * x_lw + alpha_l.unsqueeze(1).clamp(max=0) * x_uw lw = lw + beta_l.unsqueeze(1).clamp(min=0) * y_lw + beta_l.unsqueeze(1).clamp(max=0) * y_uw - lb = alpha_l.clamp(min=0) * x_lb + alpha_l.clamp(max=0) * x_ub + \ - beta_l.clamp(min=0) * y_lb + beta_l.clamp(max=0) * y_ub + gamma_l + lb = (alpha_l.clamp(min=0) * x_lb + alpha_l.clamp(max=0) * x_ub + + beta_l.clamp(min=0) * y_lb + beta_l.clamp(max=0) * y_ub + gamma_l) uw = alpha_u.unsqueeze(1).clamp(max=0) * x_lw + alpha_u.unsqueeze(1).clamp(min=0) * x_uw uw = uw + beta_u.unsqueeze(1).clamp(max=0) * y_lw + beta_u.unsqueeze(1).clamp(min=0) * y_uw - ub = alpha_u.clamp(max=0) * x_lb + alpha_u.clamp(min=0) * x_ub + \ - beta_u.clamp(max=0) * y_lb + beta_u.clamp(min=0) * y_ub + gamma_u + ub = (alpha_u.clamp(max=0) * x_lb + alpha_u.clamp(min=0) * x_ub + + beta_u.clamp(max=0) * y_lb + beta_u.clamp(min=0) * y_ub + gamma_u) return LinearBound(lw, lb, uw, ub) @staticmethod - def interval_propagate_constant(*v, op=lambda x, const: x * const): - x, y = v[0], v[1] - x_is_const = x[0] is x[1] # FIXME: using better way to represent constant perturbation. - y_is_const = y[0] is y[1] # We should not check the distance between x[0] and x[1]. It's slow! - assert x_is_const or y_is_const - const = x[0] if x_is_const else y[0] - inp_lb = x[0] if y_is_const else y[0] - inp_ub = x[1] if y_is_const else y[1] + def interval_propagate_constant(x, y, op=lambda x, const: x * const): + # x is constant + const = x[0] + inp_lb = y[0] + inp_ub = y[1] pos_mask = (const > 0).to(dtype=inp_lb.dtype) neg_mask = 1. - pos_mask lb = op(inp_lb, const * pos_mask) + op(inp_ub, const * neg_mask) ub = op(inp_ub, const * pos_mask) + op(inp_lb, const * neg_mask) return lb, ub - def interval_propagate(self, *v): - if self.is_constant_op: - return self.interval_propagate_constant(*v) + def interval_propagate(self, x, y): + if self.is_linear_op: + if not self.inputs[0].perturbed: + return self.interval_propagate_constant(x, y) + elif not self.inputs[1].perturbed: + return self.interval_propagate_constant(y, x) + else: + assert False else: - return self.interval_propagate_both_perturbed(*v) + return self.interval_propagate_both_perturbed(x, y) @staticmethod def interval_propagate_both_perturbed(*v): @@ -293,27 +317,30 @@ def interval_propagate_both_perturbed(*v): return lower, upper def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): - for vi in v: - assert isinstance(vi, Tensor), "build solver for BoundMul only with tensors for now" - self.solver_vars = v[0] * v[1] - - -class BoundDiv(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.is_constant_op = False - for inp in inputs: - if isinstance(inp, (BoundConstant, BoundBuffers)): + if isinstance(v[0], Tensor): + self.solver_vars = self.forward(*v) + return + gvar_array = np.array(v[0]) + gvar_array = gvar_array * v[1].cpu().numpy() + self.solver_vars = gvar_array.tolist() + + def update_requires_input_bounds(self): + self.is_linear_op = False + for inp in self.inputs: + if not inp.perturbed: # If any of the two inputs are constant, we do not need input bounds. - # FIXME (05/11/2022): this is just a temporary workaround. We need better way to determine whether we need input bounds, not just for BoundConstant. - # FIXME: unify this handling with BoundMul. - self.is_constant_op = True - if self.is_constant_op: + self.is_linear_op = True + if self.is_linear_op: # One input is constant; no bounds required. self.requires_input_bounds = [] + self.splittable = False else: # Both inputs are perturbed. Need relaxation. self.requires_input_bounds = [0, 1] + self.splittable = True + + +class BoundDiv(Bound): def forward(self, x, y): # FIXME (05/11/2022): ad-hoc implementation for layer normalization @@ -330,269 +357,3 @@ def forward(self, x, y): self.x, self.y = x, y return x / y - - def bound_backward(self, last_lA, last_uA, x, y): - if self.is_constant_op: - return BoundMul.bound_backward_constant(last_lA, last_uA, x, y, op=lambda x, const: BoundMul._multiply_by_const(x, 1/const)) - else: - return self.bound_backward_both_perturbed(last_lA, last_uA, x, y) - - def bound_backward_both_perturbed(self, last_lA, last_uA, x, y): - reciprocal, mul, y_r = self._convert_to_mul(x, y) - A, lower_b, upper_b = mul.bound_backward(last_lA, last_uA, x, y_r) - A_y, lower_b_y, upper_b_y = reciprocal.bound_backward(A[1][0], A[1][1], y) - if isinstance(upper_b_y, Tensor) and upper_b_y.ndim == 1: - upper_b_y = upper_b_y.unsqueeze(-1) - if isinstance(lower_b_y, Tensor) and lower_b_y.ndim == 1: - lower_b_y = lower_b_y.unsqueeze(-1) - upper_b = upper_b + upper_b_y - lower_b = lower_b + lower_b_y - return [A[0], A_y[0]], lower_b, upper_b - - def bound_forward(self, dim_in, x, y): - assert not self.is_constant_op - reciprocal, mul, y_r = self._convert_to_mul(x, y) - y_r_linear = reciprocal.bound_forward(dim_in, y) - y_r_linear.lower = y_r.lower - y_r_linear.upper = y_r.upper - return mul.bound_forward(dim_in, x, y_r_linear) - - def interval_propagate(self, *v): - if self.is_constant_op: - return BoundMul.interval_propagate_constant(*v, op=lambda x, const: x / const) - else: - return self.interval_propagate_both_perturbed(*v) - - def interval_propagate(self, *v): - # ad-hoc implementation for layer normalization - """ - Compute bounds for layer normalization - - Lower bound - 1) (x_i - mu) can be negative - - 1 / ( sqrt (1/n * sum_j Lower{(x_j-mu)^2/(x_i-mu)^2} )) - 2) (x_i - mu) cannot be negative - 1 / ( sqrt (1/n * sum_j Upper{(x_j-mu)^2/(x_i-mu)^2} )) - - Lower{(x_j-mu)^2/(x_i-mu)^2} - Lower{sum_j (x_j-mu)^2} / Upper{(x_i-mu)^2} - - Upper{(x_j-mu)^2/(x_i-mu)^2} - Upper{sum_j (x_j-mu)^2} / Lower{(x_i-mu)^2} - """ - if isinstance(self.inputs[1], BoundSqrt): - input = self.inputs[0].inputs[0] - n = input.forward_value.shape[-1] - - h_L, h_U = input.lower, input.upper - - dev_lower = ( - h_L * (1 - 1. / n) - - (h_U.sum(dim=-1, keepdim=True) - h_U) / n - ) - dev_upper = ( - h_U * (1 - 1. / n) - - (h_L.sum(dim=-1, keepdim=True) - h_L) / n - ) - - dev_sqr_lower = (1 - (dev_lower < 0).to(dev_lower.dtype) * (dev_upper > 0).to(dev_lower.dtype)) * \ - torch.min(dev_lower.abs(), dev_upper.abs())**2 - dev_sqr_upper = torch.max(dev_lower.abs(), dev_upper.abs())**2 - - sum_lower = (dev_sqr_lower.sum(dim=-1, keepdim=True) - dev_sqr_lower) / dev_sqr_upper.clamp(min=epsilon) - sqrt_lower = torch.sqrt(1. / n * (sum_lower + 1)) - sum_upper = (dev_sqr_upper.sum(dim=-1, keepdim=True) - dev_sqr_upper) / \ - dev_sqr_lower.clamp(min=epsilon) - sqrt_upper = torch.sqrt(1. / n * (sum_upper + 1)) - - lower = (dev_lower < 0).to(dev_lower.dtype) * (-1. / sqrt_lower) + (dev_lower > 0).to(dev_lower.dtype) * (1. / sqrt_upper) - upper = (dev_upper > 0).to(dev_upper.dtype) * (1. / sqrt_lower) + (dev_upper < 0).to(dev_upper.dtype) * (-1. / sqrt_upper) - - return lower, upper - - x, y = v[0], v[1] - assert (y[0] > 0).all() - return x[0] / y[1], x[1] / y[0] - - def _convert_to_mul(self, x, y): - try: - reciprocal = BoundReciprocal({}, [], 0, None) - mul = BoundMul({}, [], 0, None) - except: - # to make it compatible with previous code - reciprocal = BoundReciprocal(None, {}, [], 0, None) - mul = BoundMul(None, {}, [], 0, None) - reciprocal.output_shape = mul.output_shape = self.output_shape - reciprocal.batch_dim = mul.batch_dim = self.batch_dim - - y_r = copy.copy(y) - if isinstance(y_r, LinearBound): - y_r.lower = 1. / y.upper - y_r.upper = 1. / y.lower - else: - y_r.lower = 1. / y.upper - y_r.upper = 1. / y.lower - return reciprocal, mul, y_r - - def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): - for vi in v: - assert isinstance(vi, Tensor), "build solver for BoundDiv only with tensors for now" - self.solver_vars = v[0] / v[1] - -class BoundAdd(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used. - self.mode = options.get("conv_mode", "matrix") - - def forward(self, x, y): - self.x_shape = x.shape - self.y_shape = y.shape - return x + y - - def bound_backward(self, last_lA, last_uA, x, y): - def _bound_oneside(last_A, w): - if last_A is None: - return None - return self.broadcast_backward(last_A, w) - - uA_x = _bound_oneside(last_uA, x) - uA_y = _bound_oneside(last_uA, y) - lA_x = _bound_oneside(last_lA, x) - lA_y = _bound_oneside(last_lA, y) - return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0 - - def bound_forward(self, dim_in, x, y): - lb, ub = x.lb + y.lb, x.ub + y.ub - - def add_w(x_w, y_w, x_b, y_b): - if x_w is None and y_w is None: - return None - elif x_w is not None and y_w is not None: - return x_w + y_w - elif y_w is None: - return x_w + torch.zeros_like(y_b) - else: - return y_w + torch.zeros_like(x_b) - - lw = add_w(x.lw, y.lw, x.lb, y.lb) - uw = add_w(x.uw, y.uw, x.ub, y.ub) - - return LinearBound(lw, lb, uw, ub) - - def interval_propagate(self, x, y): - assert (not isinstance(y, Tensor)) - return x[0] + y[0], x[1] + y[1] - - def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): - if isinstance(v[0], Tensor) and isinstance(v[1], Tensor): - # constants if both inputs are tensors - self.solver_vars = self.forward(v[0], v[1]) - return - # we have both gurobi vars as inputs - this_layer_shape = self.output_shape - gvar_array1 = np.array(v[0]) - gvar_array2 = np.array(v[1]) - assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:] - - # flatten to create vars and constrs first - gvar_array1 = gvar_array1.reshape(-1) - gvar_array2 = gvar_array2.reshape(-1) - new_layer_gurobi_vars = [] - for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)): - var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0, - vtype=grb.GRB.CONTINUOUS, - name=f'lay{self.name}_{neuron_idx}') - model.addConstr(var == (var1 + var2), name=f'lay{self.name}_{neuron_idx}_eq') - new_layer_gurobi_vars.append(var) - - # reshape to the correct list shape of solver vars - self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist() - model.update() - -class BoundSub(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - # FIXME: This is not the right way to enable patches mode. Instead we must traverse the graph and determine when patches mode needs to be used. - self.mode = options.get("conv_mode", "matrix") - - def forward(self, x, y): - self.x_shape = x.shape - self.y_shape = y.shape - return x - y - - def bound_backward(self, last_lA, last_uA, x, y): - def _bound_oneside(last_A, w, sign=-1): - if last_A is None: - return None - if isinstance(last_A, torch.Tensor): - return self.broadcast_backward(sign * last_A, w) - elif isinstance(last_A, Patches): - if sign == 1: - # Patches shape requires no broadcast. - return last_A - else: - # Multiply by the sign. - return last_A.create_similar(sign * last_A.patches) - else: - raise ValueError(f'Unknown last_A type {type(last_A)}') - - uA_x = _bound_oneside(last_uA, x, sign=1) - uA_y = _bound_oneside(last_uA, y, sign=-1) - lA_x = _bound_oneside(last_lA, x, sign=1) - lA_y = _bound_oneside(last_lA, y, sign=-1) - return [(lA_x, uA_x), (lA_y, uA_y)], 0, 0 - - def bound_forward(self, dim_in, x, y): - lb, ub = x.lb - y.ub, x.ub - y.lb - - def add_w(x_w, y_w, x_b, y_b): - if x_w is None and y_w is None: - return None - elif x_w is not None and y_w is not None: - return x_w + y_w - elif y_w is None: - return x_w + torch.zeros_like(y_b) - else: - return y_w + torch.zeros_like(x_b) - - lw = add_w(x.lw, -y.uw, x.lb, y.lb) - uw = add_w(x.uw, -y.lw, x.ub, y.ub) - - return LinearBound(lw, lb, uw, ub) - - def interval_propagate(self, x, y): - return x[0] - y[1], x[1] - y[0] - - def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): - if isinstance(v[0], Tensor) and isinstance(v[1], Tensor): - # constants if both inputs are tensors - self.solver_vars = self.forward(v[0], v[1]) - return - # we have both gurobi vars as inputs - this_layer_shape = self.output_shape - gvar_array1 = np.array(v[0]) - gvar_array2 = np.array(v[1]) - assert gvar_array1.shape == gvar_array2.shape and gvar_array1.shape == this_layer_shape[1:] - - # flatten to create vars and constrs first - gvar_array1 = gvar_array1.reshape(-1) - gvar_array2 = gvar_array2.reshape(-1) - new_layer_gurobi_vars = [] - for neuron_idx, (var1, var2) in enumerate(zip(gvar_array1, gvar_array2)): - var = model.addVar(lb=-float('inf'), ub=float('inf'), obj=0, - vtype=grb.GRB.CONTINUOUS, - name=f'lay{self.name}_{neuron_idx}') - model.addConstr(var == (var1 - var2), name=f'lay{self.name}_{neuron_idx}_eq') - new_layer_gurobi_vars.append(var) - - # reshape to the correct list shape of solver vars - self.solver_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape[1:]).tolist() - model.update() - -class BoundEqual(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - - def forward(self, x, y): - return x == y diff --git a/auto_LiRPA/operators/clampmult.py b/auto_LiRPA/operators/clampmult.py index 7241fb6..a9c9501 100644 --- a/auto_LiRPA/operators/clampmult.py +++ b/auto_LiRPA/operators/clampmult.py @@ -1,6 +1,5 @@ """Element multiplication with the A matrix based on its sign.""" import torch -import time from typing import Optional, Tuple from torch import Tensor from ..patches import Patches @@ -10,58 +9,52 @@ torch._C._jit_set_profiling_mode(False) -# @torch.jit.script -def _reference_multiply_by_A_signs(A: Tensor, d_pos: Tensor, d_neg: Tensor, - b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]: - """Reference implementation.""" - A_pos = A.clamp(min=0) - A_neg = A.clamp(max=0) - A_new = d_pos * A_pos + d_neg * A_neg - bias_pos = bias_neg = torch.tensor(0.) - if b_pos is not None: - if patches_mode: - bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos) - else: - bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos) - if b_neg is not None: - if patches_mode: - bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg) - else: - bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg) - return A_new, bias_pos + bias_neg - - class ClampedMultiplication(torch.autograd.Function): @staticmethod + @torch.no_grad() @torch.jit.script def clamp_mutiply_forward(A: Tensor, d_pos: Tensor, d_neg: Tensor, - b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool) -> Tuple[Tensor, Tensor]: + b_pos: Optional[Tensor], b_neg: Optional[Tensor], patches_mode: bool, + reduce_bias: bool = False + ) -> Tuple[Tensor, Tensor]: """Forward operations; actually the same as the reference implementation.""" A_pos = A.clamp(min=0) A_neg = A.clamp(max=0) A_new = d_pos * A_pos + d_neg * A_neg - bias_pos = bias_neg = torch.tensor(0.) + bias_pos = bias_neg = torch.zeros( + (), dtype=A_new.dtype, device=A_new.device) if b_pos is not None: - if patches_mode: - bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos) + if not reduce_bias: + bias_pos = A_pos * b_pos else: - bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos) + if patches_mode: + bias_pos = torch.einsum('sb...chw,sb...chw->sb...', A_pos, b_pos) + else: + bias_pos = torch.einsum('sb...,sb...->sb', A_pos, b_pos) if b_neg is not None: - if patches_mode: - bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg) + if not reduce_bias: + bias_neg = A_neg * b_neg else: - bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg) + if patches_mode: + bias_neg = torch.einsum('sb...chw,sb...chw->sb...', A_neg, b_neg) + else: + bias_neg = torch.einsum('sb...,sb...->sb', A_neg, b_neg) return A_new, bias_pos + bias_neg @staticmethod + @torch.no_grad() @torch.jit.script def clamp_mutiply_backward(A: Tensor, d_pos: Tensor, d_neg: Tensor, - b_pos: Optional[Tensor], b_neg: Optional[Tensor], grad_output_A: Tensor, grad_output_bias: Optional[Tensor], - patches_mode: bool) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], None]: - """Improved backward operation. This should be better than the backward function generated by Pytorch.""" + b_pos: Optional[Tensor], b_neg: Optional[Tensor], + grad_output_A: Tensor, grad_output_bias: Optional[Tensor], + ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], Optional[Tensor], + None, None]: + """Improved backward operation. This should be better than the backward + function generated by Pytorch.""" if grad_output_bias is not None: extension_dim = len(A.shape) - len(grad_output_bias.shape) - grad_output_bias = grad_output_bias.view(grad_output_bias.shape + (1, ) * extension_dim) + grad_output_bias = grad_output_bias.view( + grad_output_bias.shape + (1, ) * extension_dim) A_pos_mask = (A >= 0).to(dtype=grad_output_A.dtype) A_neg_mask = 1. - A_pos_mask A_pos_grad_output_A = A_pos_mask * grad_output_A @@ -74,40 +67,51 @@ def clamp_mutiply_backward(A: Tensor, d_pos: Tensor, d_neg: Tensor, gb_neg = A * A_neg_grad_output_bias gb_pos = A * A_pos_grad_output_bias # gA has 4 terms. - gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_pos * A_pos_grad_output_bias + b_neg * A_neg_grad_output_bias + gA = (d_pos * A_pos_grad_output_A + + d_neg * A_neg_grad_output_A + + b_pos * A_pos_grad_output_bias + + b_neg * A_neg_grad_output_bias) elif b_neg is not None and grad_output_bias is not None: A_neg_grad_output_bias = A_neg_mask * grad_output_bias gb_neg = A * A_neg_grad_output_bias gb_pos = None # gA has 3 terms. - gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_neg * A_neg_grad_output_bias + gA = (d_pos * A_pos_grad_output_A + + d_neg * A_neg_grad_output_A + + b_neg * A_neg_grad_output_bias) elif b_pos is not None and grad_output_bias is not None: A_pos_grad_output_bias = A_pos_mask * grad_output_bias gb_pos = A * A_pos_grad_output_bias gb_neg = None # gA has 3 terms. - gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + b_pos * A_pos_grad_output_bias + gA = (d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A + + b_pos * A_pos_grad_output_bias) else: # gA has 2 terms. gA = d_pos * A_pos_grad_output_A + d_neg * A_neg_grad_output_A gb_pos = gb_neg = None - return gA, gd_pos, gd_neg, gb_pos, gb_neg, None + return gA, gd_pos, gd_neg, gb_pos, gb_neg, None, None @staticmethod - def forward(ctx, A, d_pos, d_neg, b_pos, b_neg, patches_mode): + def forward(ctx, A, d_pos, d_neg, b_pos, b_neg, patches_mode, reduce_bias=True): # No need to save the intermediate A_pos, A_neg as they have been fused into the computation. ctx.save_for_backward(A, d_pos, d_neg, b_pos, b_neg) ctx.patches_mode = patches_mode - return ClampedMultiplication.clamp_mutiply_forward(A, d_pos, d_neg, b_pos, b_neg, patches_mode) + ctx.reduce_bias = reduce_bias + return ClampedMultiplication.clamp_mutiply_forward( + A, d_pos, d_neg, b_pos, b_neg, patches_mode, reduce_bias) @staticmethod def backward(ctx, grad_output_A, grad_output_bias): A, d_pos, d_neg, b_pos, b_neg = ctx.saved_tensors - patches_mode = ctx.patches_mode - return ClampedMultiplication.clamp_mutiply_backward(A, d_pos, d_neg, b_pos, b_neg, grad_output_A, grad_output_bias, patches_mode) + assert ctx.reduce_bias + return ClampedMultiplication.clamp_mutiply_backward( + A, d_pos, d_neg, b_pos, b_neg, + grad_output_A, grad_output_bias) -def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'): +def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto', + reduce_bias=True): if isinstance(A, Tensor): if contiguous is True or contiguous == 'auto': # For dense mode, convert d_pos and d_neg to contiguous tensor by default. @@ -119,7 +123,8 @@ def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'): new_A = A.clamp(min=0) * d_pos + A.clamp(max=0) * d_neg new_bias = A.clamp(min=0) * b_pos + A.clamp(max=0) * b_neg return new_A, new_bias - return ClampedMultiplication.apply(A, d_pos, d_neg, b_pos, b_neg, False) + return ClampedMultiplication.apply( + A, d_pos, d_neg, b_pos, b_neg, False, reduce_bias) elif isinstance(A, Patches): if contiguous: # For patches mode, do not convert d_pos and d_neg to contiguous tensor by default. @@ -137,7 +142,8 @@ def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'): b_pos = b_pos.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_pos is not None else None b_neg = b_neg.view(*patches_shape[:2], -1, *patches_shape[-2:]) if b_neg is not None else None # Apply the multiplication based on signs. - A_prod, bias = ClampedMultiplication.apply(patches, d_pos, d_neg, b_pos, b_neg, True) + A_prod, bias = ClampedMultiplication.apply( + patches, d_pos, d_neg, b_pos, b_neg, True, reduce_bias) # prod has shape [out_c, batch_size, out_h, out_w, in_c, H, W] or (unstable_size, batch_size, in_c, H, W) when it is sparse. # For sparse patches the return bias size is (unstable_size, batch). # For regular patches the return bias size is (spec, batch, out_h, out_w). @@ -145,94 +151,3 @@ def multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, contiguous='auto'): A_prod = A_prod.view(*patches_shape) return A.create_similar(A_prod), bias - -def _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=False, n_test=20, warmup=3): - """Benchmarking function.""" - print(f'patches_mode = {patches_mode}, b_pos is {type(b_pos)}, b_neg is {type(b_neg)}') - total_ref = 0. - total_new = 0. - run = ['ref', 'new'] - for i in range(n_test): - ref_time = new_time = 0. - - if 'ref' in run: - torch.cuda.synchronize() - start = time.time() - ref_A, ref_bias = _reference_multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, patches_mode) - ref_loss = ref_A.sum() + ref_bias.sum() - ref_loss.backward() - torch.cuda.synchronize() - ref_time = time.time() - start - ref_gA = A.grad.detach().clone() - ref_gd_pos = d_pos.grad.detach().clone() - ref_gd_neg = d_neg.grad.detach().clone() - ref_gb_pos = b_pos.grad.detach().clone() if b_pos is not None else torch.tensor(0.) - ref_gb_neg = b_neg.grad.detach().clone() if b_neg is not None else torch.tensor(0.) - A.grad = d_pos.grad = d_neg.grad = None - if b_pos is not None: - b_pos.grad = None - if b_neg is not None: - b_neg.grad = None - del ref_loss - - if 'new' in run: - torch.cuda.synchronize() - start = time.time() - new_A, new_bias = multiply_by_A_signs(A, d_pos, d_neg, b_pos, b_neg, patches_mode) - new_loss = new_A.sum() + new_bias.sum() - new_loss.backward() - torch.cuda.synchronize() - new_time = time.time() - start - new_gA = A.grad.detach().clone() - new_gd_pos = d_pos.grad.detach().clone() - new_gd_neg = d_neg.grad.detach().clone() - new_gb_pos = b_pos.grad.detach().clone() if b_pos is not None else torch.tensor(0.) - new_gb_neg = b_neg.grad.detach().clone() if b_neg is not None else torch.tensor(0.) - A.grad = d_pos.grad = d_neg.grad = None - if b_pos is not None: - b_pos.grad = None - if b_neg is not None: - b_neg.grad = None - del new_loss - - print(f'Loop {i:3d} {"(warmup)" if i < warmup else " "} time ref {ref_time:.5f} new {new_time:.6f} speedup {ref_time / new_time if i >= warmup else float("nan"):.3f}') - if i >= warmup: - total_ref += ref_time - total_new += new_time - - if 'ref' in run and 'new' in run: - A_diff = (ref_A - new_A).abs().sum().item() / ref_A.abs().sum().item() - gA_diff = (ref_gA - new_gA).abs().sum().item() / ref_gA.abs().sum().item() - bias_diff = (ref_bias - new_bias).abs().sum().item() / (ref_bias.abs().sum().item() + 1e-10) - gd_pos_diff = (ref_gd_pos - new_gd_pos).abs().sum().item() / ref_gd_pos.abs().sum().item() - gd_neg_diff = (ref_gd_neg - new_gd_neg).abs().sum().item() / ref_gd_neg.abs().sum().item() - gb_pos_diff = (ref_gb_pos - new_gb_pos).abs().sum().item() / (ref_gb_pos.abs().sum().item() + 1e-10) - gb_neg_diff = (ref_gb_neg - new_gb_neg).abs().sum().item() / (ref_gb_neg.abs().sum().item() + 1e-10) - print(f' diff {A_diff} {gA_diff} {bias_diff} {gd_pos_diff} {gd_neg_diff} {gb_pos_diff} {gb_neg_diff}') - assert A_diff < 1e-6 and bias_diff < 1e-6 and gA_diff < 1e-6 and gd_pos_diff < 1e-6 and gd_neg_diff < 1e-6 - assert gb_pos_diff < 1e-6 and gb_neg_diff < 1e-6 - - - avg_ref_time = total_ref / (n_test - warmup) - avg_new_time = total_new / (n_test - warmup) - print(f'Avg. time: reference {avg_ref_time:.5f} new {avg_new_time:.6f} speedup {avg_ref_time / avg_new_time:.3f}') - - -if __name__ == '__main__': - for patches_mode in [True, False]: - if patches_mode: - shape = (256, 8, 8, 8, 16, 32) - else: - shape = (256, 8, 128, 256) - A = torch.randn(shape, device='cuda', requires_grad=True) - d_pos = torch.randn(shape, device='cuda', requires_grad=True) - d_neg = torch.randn(shape, device='cuda', requires_grad=True) - b_pos = torch.randn(shape, device='cuda', requires_grad=True) - b_neg = torch.randn(shape, device='cuda', requires_grad=True) - _speed_test(A, d_pos, d_neg, None, None, patches_mode=patches_mode) - _speed_test(A, d_pos, d_neg, None, b_neg, patches_mode=patches_mode) - _speed_test(A, d_pos, d_neg, b_pos, None, patches_mode=patches_mode) - _speed_test(A, d_pos, d_neg, b_pos, b_neg, patches_mode=patches_mode) - print('Press Enter key to continue.') - input() - del A, d_pos, d_neg, b_pos, b_neg diff --git a/auto_LiRPA/operators/constant.py b/auto_LiRPA/operators/constant.py index de11a0e..ffa92e3 100644 --- a/auto_LiRPA/operators/constant.py +++ b/auto_LiRPA/operators/constant.py @@ -1,16 +1,23 @@ """ Constant operators, including operators that are usually fixed nodes and not perturbed """ from .base import * + class BoundConstant(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.value = attr['value'].to(self.device) self.use_default_ibp = True + def __repr__(self): + if self.value.numel() == 1: + return f'BoundConstant(name={self.name}, value={self.value})' + else: + return super().__repr__() + def forward(self): return self.value.to(self.device) - def bound_backward(self, last_lA, last_uA): + def bound_backward(self, last_lA, last_uA, **kwargs): def _bound_oneside(A): if A is None: return 0.0 @@ -40,14 +47,11 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") class BoundPrimConstant(Bound): - def __init__(self, attr, input, output_index, options): - super().__init__(attr, input, output_index, options) - def forward(self): return torch.tensor([], device=self.device) class BoundConstantOfShape(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.value = attr['value'].to(self.device) @@ -56,7 +60,7 @@ def forward(self, x): self.from_input = True return self.value.expand(*list(x)) - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): if last_lA is not None: lower_sum_b = last_lA * self.value while lower_sum_b.ndim > 2: @@ -81,15 +85,14 @@ def bound_forward(self, dim_in, x): def interval_propagate(self, *v): self.x = v[0][0] - size = int(v[0][0].item()) if isinstance(v[0][0], Tensor) else v[0][0] - value = torch.ones(size, device=self.device) * self.value + value = torch.ones(tuple(v[0][0]), device=self.device) * self.value return value, value def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): self.solver_vars = self.forward(v) class BoundRange(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.device = attr['device'] @@ -100,7 +103,7 @@ def forward(self, start, end, step): return torch.arange(start, end, step, device=self.device) class BoundATenDiag(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.device = attr['device'] @@ -111,7 +114,7 @@ def interval_propagate(self, *v): return Interval.make_interval(torch.diag(v[0][0], v[1][0]), torch.diag(v[0][1], v[1][0]), v[0]) class BoundATenDiagonal(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.device = attr['device'] diff --git a/auto_LiRPA/operators/convolution.py b/auto_LiRPA/operators/convolution.py index 7ffbfc5..241fd65 100644 --- a/auto_LiRPA/operators/convolution.py +++ b/auto_LiRPA/operators/convolution.py @@ -2,19 +2,30 @@ from .base import * import numpy as np from .solver_utils import grb -from ..patches import unify_shape, compute_patches_stride_padding, is_shape_used +from ..patches import unify_shape, compute_patches_stride_padding, is_shape_used, create_valid_mask from .gradient_modules import Conv2dGrad +EPS = 1e-2 class BoundConv(Bound): - def __init__(self, attr, inputs, output_index, options): - assert (attr['pads'][0] == attr['pads'][2]) - assert (attr['pads'][1] == attr['pads'][3]) - + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) + if len(attr['kernel_shape']) == 1: + # for 1d conv + assert (attr['pads'][0] == attr['pads'][1]) + self.padding = [attr['pads'][0]] + self.F_conv = F.conv1d + self.conv_dim = 1 + else: + # for 2d conv + assert (attr['pads'][0] == attr['pads'][2]) + assert (attr['pads'][1] == attr['pads'][3]) + self.padding = [attr['pads'][0], attr['pads'][1]] + self.F_conv = F.conv2d + self.conv_dim = 2 + self.stride = attr['strides'] - self.padding = [attr['pads'][0], attr['pads'][1]] self.dilation = attr['dilations'] self.groups = attr['group'] if len(inputs) == 3: @@ -32,10 +43,12 @@ def __init__(self, attr, inputs, output_index, options): def forward(self, *x): # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias bias = x[2] if self.has_bias else None - output = F.conv2d(x[0], x[1], bias, self.stride, self.padding, self.dilation, self.groups) + + output = self.F_conv(x[0], x[1], bias, self.stride, self.padding, self.dilation, self.groups) + return output - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): if self.is_input_perturbed(1): raise NotImplementedError( 'Weight perturbation for convolution layers has not been implmented.') @@ -48,67 +61,56 @@ def _bound_oneside(last_A): return None, 0 if type(last_A) is OneHotC: # Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead. - shape = last_A.shape # [spec, batch, C, H, W] - dim = int(prod(shape[2:])) - dense_last_A = torch.zeros( - size=(shape[0], shape[1], dim), device=last_A.device, dtype=weight.dtype) - # last_A.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A. - # last_A.coeffs is the value of the non-zero element. - dense_last_A = torch.scatter( - dense_last_A, dim=2, index=last_A.index.unsqueeze(-1), - src=last_A.coeffs.unsqueeze(-1)) - # We created a large A matrix and it will be handled below. - last_A = dense_last_A.view(shape[0], shape[1], *shape[2:]) + last_A = onehotc_to_dense(last_A, dtype=weight.dtype) if type(last_A) == Tensor: shape = last_A.size() # when (W−F+2P)%S != 0, construct the output_padding - output_padding0 = ( - int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 * - self.padding[0] - 1 - (int(weight.size()[2] - 1) * self.dilation[0])) - output_padding1 = ( - int(self.input_shape[3]) - (int(self.output_shape[3]) - 1) * self.stride[1] + 2 * - self.padding[1] - 1 - (int(weight.size()[3] - 1) * self.dilation[0])) - next_A = F.conv_transpose2d( - last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None, - stride=self.stride, padding=self.padding, dilation=self.dilation, - groups=self.groups, output_padding=(output_padding0, output_padding1)) + if self.conv_dim == 2: + output_padding0 = ( + int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 * + self.padding[0] - 1 - (int(weight.size()[2] - 1) * self.dilation[0])) + output_padding1 = ( + int(self.input_shape[3]) - (int(self.output_shape[3]) - 1) * self.stride[1] + 2 * + self.padding[1] - 1 - (int(weight.size()[3] - 1) * self.dilation[0])) + next_A = F.conv_transpose2d( + last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None, + stride=self.stride, padding=self.padding, dilation=self.dilation, + groups=self.groups, output_padding=(output_padding0, output_padding1)) + else: + # for 1d conv, we use conv_transpose1d() + output_padding = ( + int(self.input_shape[2]) - (int(self.output_shape[2]) - 1) * self.stride[0] + 2 * + self.padding[0] - 1 - (int(weight.size()[2] - 1) * self.dilation[0])) + next_A = F.conv_transpose1d( + last_A.reshape(shape[0] * shape[1], *shape[2:]), weight, None, + stride=self.stride, padding=self.padding, dilation=self.dilation, + groups=self.groups, output_padding=output_padding) + next_A = next_A.view(shape[0], shape[1], *next_A.shape[1:]) if self.has_bias: # sum_bias = (last_A.sum((3, 4)) * x[2].lower).sum(2) - sum_bias = torch.einsum('sbchw,c->sb', last_A, x[2].lower) + sum_bias = torch.einsum('sbc...,c->sb', last_A, x[2].lower) else: sum_bias = 0 return next_A, sum_bias elif type(last_A) == Patches: # Here we build and propagate a Patch object with (patches, stride, padding) + assert self.conv_dim == 2, 'Patches mode not supports conv1d so far.' assert type(last_A) == Patches if last_A.identity == 0: # FIXME (09/20): Don't call it relu_followed. Instead, make this a property of A, called "padded" and propagate this property. if not self.relu_followed: # The last_A.patches was not padded, so we need to pad them here. # If this Conv layer is followed by a ReLU layer, then the padding was already handled there and there is no need to pad again. - one_d = torch.ones( - tuple(1 for i in self.output_shape[1:]), - device=last_A.patches.device, dtype=weight.dtype - ).expand(self.output_shape[1:]) - # Add batch dimension. - one_d = one_d.unsqueeze(0) - # After unfolding, the shape is (1, out_h, out_w, in_c, h, w) - one_d_unfolded = inplace_unfold( - one_d, kernel_size=last_A.patches.shape[-2:], - stride=last_A.stride, padding=last_A.padding, - inserted_zeros=last_A.inserted_zeros, - output_padding=last_A.output_padding) - if last_A.unstable_idx is not None: - # Move out_h, out_w dimension to the front for easier selection. - one_d_unfolded_r = one_d_unfolded.permute(1, 2, 0, 3, 4, 5) - # for sparse patches the shape is (unstable_size, batch, in_c, h, w). - # Batch size is 1 so no need to select here. - one_d_unfolded_r = one_d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] - else: - # Append the spec dimension. - one_d_unfolded_r = one_d_unfolded.unsqueeze(0) + one_d_unfolded_r = create_valid_mask(self.output_shape, last_A.patches.device, + weight.dtype, + last_A.patches.shape[-2:], + last_A.stride, + last_A.inserted_zeros, + last_A.padding, + last_A.output_padding, + last_A.unstable_idx if last_A.unstable_idx else None) patches = last_A.patches * one_d_unfolded_r else: patches = last_A.patches @@ -185,12 +187,12 @@ def _bound_oneside(last_A): sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1) A_matrix = A_matrix.transpose(0,1) # Spec dimension at the front. return A_matrix, sum_bias - # print(f'Conv returns patches with size={pieces.size()}, stride={stride}, padding={padding}, inserted_zeros={inserted_zeros}, output_padding={output_padding}') - return Patches(pieces, stride, padding, pieces.shape, - unstable_idx=last_A.unstable_idx, - output_shape=last_A.output_shape, - inserted_zeros=last_A.inserted_zeros, - output_padding=output_padding), sum_bias + new_patches = last_A.create_similar( + pieces, stride=stride, padding=padding, output_padding=output_padding, + identity=0, input_shape=self.input_shape) + # if last_A is last_lA: + # print(f'Conv : start_node {kwargs["start_node"].name} layer {self.name} {new_patches}') + return new_patches, sum_bias else: raise NotImplementedError() @@ -212,8 +214,8 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") out_lbs, out_ubs = None, None if hasattr(self, "lower"): # self.lower shape (1,8,16,16) - out_lbs = self.lower.cpu().numpy() - out_ubs = self.upper.cpu().numpy() + out_lbs = self.lower.detach().cpu().numpy() + out_ubs = self.upper.detach().cpu().numpy() # current layer weight (8,3,4,4) this_layer_weight = v[1].detach().cpu().numpy() @@ -228,55 +230,102 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") new_layer_gurobi_vars = [] new_layer_gurobi_constrs = [] + # precompute row and column index mappings + + # compute row mapping: from current row to input rows + # vectorization of following code: + # for out_row_idx in range(this_layer_shape[2]): + # ker_row_min, ker_row_max = 0, weight_shape2 + # in_row_idx_min = -padding0 + stride0 * out_row_idx + # in_row_idx_max = in_row_idx_min + weight_shape2 - 1 + # if in_row_idx_min < 0: + # ker_row_min = -in_row_idx_min + # if in_row_idx_max >= pre_layer_shape[2]: + # ker_row_max = ker_row_max - in_row_idx_max + pre_layer_shape[2] - 1 + # in_row_idx_min, in_row_idx_max = max(in_row_idx_min, 0), min(in_row_idx_max, + # pre_layer_shape[2] - 1) + in_row_idx_mins = np.arange(this_layer_shape[2]) * stride0 - padding0 + in_row_idx_maxs = in_row_idx_mins + weight_shape2 - 1 + ker_row_mins = np.zeros(this_layer_shape[2], dtype=int) + ker_row_maxs = np.ones(this_layer_shape[2], dtype=int) * weight_shape2 + ker_row_mins[in_row_idx_mins < 0] = -in_row_idx_mins[in_row_idx_mins < 0] + ker_row_maxs[in_row_idx_maxs >= pre_layer_shape[2]] = \ + ker_row_maxs[in_row_idx_maxs >= pre_layer_shape[2]] - in_row_idx_maxs[in_row_idx_maxs >= pre_layer_shape[2]]\ + + pre_layer_shape[2] - 1 + in_row_idx_mins = np.maximum(in_row_idx_mins, 0) + in_row_idx_maxs = np.minimum(in_row_idx_maxs, pre_layer_shape[2] - 1) + + # compute column mapping: from current column to input columns + # vectorization of following code: + # for out_col_idx in range(this_layer_shape[3]): + # ker_col_min, ker_col_max = 0, weight_shape3 + # in_col_idx_min = -padding1 + stride1 * out_col_idx + # in_col_idx_max = in_col_idx_min + weight_shape3 - 1 + # if in_col_idx_min < 0: + # ker_col_min = -in_col_idx_min + # if in_col_idx_max >= pre_layer_shape[3]: + # ker_col_max = ker_col_max - in_col_idx_max + pre_layer_shape[3] - 1 + # in_col_idx_min, in_col_idx_max = max(in_col_idx_min, 0), min(in_col_idx_max, + # pre_layer_shape[3] - 1) + in_col_idx_mins = np.arange(this_layer_shape[3]) * stride1 - padding1 + in_col_idx_maxs = in_col_idx_mins + weight_shape3 - 1 + ker_col_mins = np.zeros(this_layer_shape[3], dtype=int) + ker_col_maxs = np.ones(this_layer_shape[3], dtype=int) * weight_shape3 + ker_col_mins[in_col_idx_mins < 0] = -in_col_idx_mins[in_col_idx_mins < 0] + ker_col_maxs[in_col_idx_maxs >= pre_layer_shape[3]] = \ + ker_col_maxs[in_col_idx_maxs >= pre_layer_shape[3]] - in_col_idx_maxs[in_col_idx_maxs >= pre_layer_shape[3]]\ + + pre_layer_shape[3] - 1 + in_col_idx_mins = np.maximum(in_col_idx_mins, 0) + in_col_idx_maxs = np.minimum(in_col_idx_maxs, pre_layer_shape[3] - 1) + neuron_idx = 0 for out_chan_idx in range(this_layer_shape[1]): out_chan_vars = [] for out_row_idx in range(this_layer_shape[2]): out_row_vars = [] + + # get row index range from precomputed arrays + ker_row_min, ker_row_max = ker_row_mins[out_row_idx], ker_row_maxs[out_row_idx] + in_row_idx_min, in_row_idx_max = in_row_idx_mins[out_row_idx], in_row_idx_maxs[out_row_idx] + for out_col_idx in range(this_layer_shape[3]): - # print(this_layer_bias.shape, out_chan_idx, out_lbs.size(1)) - lin_expr = 0 - if self.has_bias: - lin_expr = this_layer_bias[out_chan_idx] - for in_chan_idx in range(this_layer_weight.shape[1]): + # get col index range from precomputed arrays + ker_col_min, ker_col_max = ker_col_mins[out_col_idx], ker_col_maxs[out_col_idx] + in_col_idx_min, in_col_idx_max = in_col_idx_mins[out_col_idx], in_col_idx_maxs[out_col_idx] - # new version of conv layer for building mip by skipping kernel loops - ker_row_min, ker_row_max = 0, weight_shape2 - in_row_idx_min = -padding0 + stride0 * out_row_idx - in_row_idx_max = in_row_idx_min + weight_shape2 - 1 - if in_row_idx_min < 0: - ker_row_min = -in_row_idx_min - if in_row_idx_max >= pre_layer_shape[2]: - ker_row_max = ker_row_max - in_row_idx_max + pre_layer_shape[2] -1 - in_row_idx_min, in_row_idx_max = max(in_row_idx_min, 0), min(in_row_idx_max, pre_layer_shape[2] - 1) - - ker_col_min, ker_col_max = 0, weight_shape3 - in_col_idx_min = -padding1 + stride1 * out_col_idx - in_col_idx_max = in_col_idx_min + weight_shape3 - 1 - if in_col_idx_min < 0: - ker_col_min = -in_col_idx_min - if in_col_idx_max >= pre_layer_shape[3]: - ker_col_max = ker_col_max - in_col_idx_max + pre_layer_shape[3] -1 - in_col_idx_min, in_col_idx_max = max(in_col_idx_min, 0), min(in_col_idx_max, pre_layer_shape[3] - 1) + # init linear expression + lin_expr = this_layer_bias[out_chan_idx] if self.has_bias else 0 - coeffs = this_layer_weight[out_chan_idx, in_chan_idx, ker_row_min:ker_row_max, ker_col_min:ker_col_max].reshape(-1) + # init linear constraint LHS implied by the conv operation + for in_chan_idx in range(this_layer_weight.shape[1]): + coeffs = this_layer_weight[out_chan_idx, in_chan_idx, ker_row_min:ker_row_max, ker_col_min:ker_col_max].reshape(-1) gvars = gvars_array[in_chan_idx, in_row_idx_min:in_row_idx_max+1, in_col_idx_min:in_col_idx_max+1].reshape(-1) if solver_pkg == 'gurobi': lin_expr += grb.LinExpr(coeffs, gvars) else: - # lin_expr += coeffs@gvars - for i in range(len(coeffs)): try: lin_expr += coeffs[i] * gvars[i] except TypeError: lin_expr += coeffs[i] * gvars[i].var - + # init potential lb and ub, which helps solver to finish faster out_lb = out_lbs[0, out_chan_idx, out_row_idx, out_col_idx] if out_lbs is not None else -float('inf') out_ub = out_ubs[0, out_chan_idx, out_row_idx, out_col_idx] if out_ubs is not None else float('inf') + if out_ub - out_lb < EPS: + """ + If the inferred lb and ub are too close, it could lead to floating point disagreement + between solver's inferred lb and ub constraints and the computed ones from ab-crown. + Such disagreement can lead to "infeasible" result from the solver for feasible problem. + To avoid so, we relax the box constraints. + This should not affect the solver's result correctness, + since the tighter lb and ub can be inferred by the solver. + """ + out_lb, out_ub = (out_lb + out_ub - EPS) / 2., (out_lb + out_ub + EPS) / 2. + + # add the output var and constraint var = model.addVar(lb=out_lb, ub=out_ub, obj=0, vtype=grb.GRB.CONTINUOUS, # name=f'lay{layer_idx}_[{out_chan_idx}, {out_row_idx}, {out_col_idx}]') @@ -306,16 +355,17 @@ def interval_propagate(self, *v, C=None): weight = v[1][0] bias = v[2][0] if self.has_bias else None - if norm == np.inf: + if norm == torch.inf: mid = (h_U + h_L) / 2.0 diff = (h_U - h_L) / 2.0 weight_abs = weight.abs() - deviation = F.conv2d(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups) + deviation = self.F_conv(diff, weight_abs, None, self.stride, self.padding, self.dilation, self.groups) elif norm > 0: norm, eps = Interval.get_perturbation(v[0]) # L2 norm, h_U and h_L are the same. mid = h_U # TODO: padding + assert not isinstance(eps, torch.Tensor) or eps.numel() == 1 deviation = torch.mul(weight, weight).sum((1, 2, 3)).sqrt() * eps deviation = deviation.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) else: # Here we calculate the L0 norm IBP bound using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020] @@ -326,14 +376,14 @@ def interval_propagate(self, *v, C=None): deviation = torch.sum(torch.topk(weight_sum.view(weight_sum.shape[0], -1), k)[0], dim=1) * ratio if self.has_bias: - center = F.conv2d(mid, weight, v[2][0], self.stride, self.padding, self.dilation, self.groups) + center = self.F_conv(mid, weight, v[2][0], self.stride, self.padding, self.dilation, self.groups) else: - center = F.conv2d(mid, weight, None, self.stride, self.padding, self.dilation, self.groups) + center = self.F_conv(mid, weight, None, self.stride, self.padding, self.dilation, self.groups) ss = center.shape deviation = deviation.repeat(ss[2] * ss[3]).view(-1, ss[1]).t().view(ss[1], ss[2], ss[3]) - center = F.conv2d(mid, weight, bias, self.stride, self.padding, self.dilation, self.groups) + center = self.F_conv(mid, weight, bias, self.stride, self.padding, self.dilation, self.groups) upper = center + deviation lower = center - deviation @@ -357,12 +407,12 @@ def conv2d(input, weight, bias, stride, padding, dilation, groups): if input.device != torch.device('cpu') and input.shape[0] > max_batch_size: ret = [] for i in range((input.shape[0] + max_batch_size - 1) // max_batch_size): - ret.append(F.conv2d( + ret.append(self.F_conv( input[i*max_batch_size:(i+1)*max_batch_size], weight, bias, stride, padding, dilation, groups)) return torch.cat(ret, dim=0) else: - return F.conv2d(input, weight, bias, stride, padding, dilation, groups) + return self.F_conv(input, weight, bias, stride, padding, dilation, groups) w_new = conv2d( w.reshape(shape_wconv), weight, None, self.stride, self.padding, self.dilation, self.groups) @@ -386,16 +436,16 @@ def bound_forward(self, dim_in, *x): weight_abs = weight.abs() shape = mid_w.shape shape_wconv = [shape[0] * shape[1]] + list(shape[2:]) - deviation_w = F.conv2d( + deviation_w = self.F_conv( diff_w.reshape(shape_wconv), weight_abs, None, self.stride, self.padding, self.dilation, self.groups) - deviation_b = F.conv2d( + deviation_b = self.F_conv( diff_b, weight_abs, None, self.stride, self.padding, self.dilation, self.groups) - center_w = F.conv2d( + center_w = self.F_conv( mid_w.reshape(shape_wconv), weight, None, self.stride, self.padding, self.dilation, self.groups) - center_b = F.conv2d( + center_b = self.F_conv( mid_b, weight, bias, self.stride, self.padding, self.dilation, self.groups) deviation_w = deviation_w.reshape(shape[0], -1, *deviation_w.shape[1:]) @@ -413,9 +463,12 @@ def build_gradient_node(self, grad_upstream): self.dilation, self.groups) return node_grad, (grad_upstream,), [] + def update_requires_input_bounds(self): + self._check_weight_perturbation() + class BoundConvTranspose(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) assert (attr['pads'][0] == attr['pads'][2]) assert (attr['pads'][1] == attr['pads'][3]) @@ -425,17 +478,19 @@ def __init__(self, attr, inputs, output_index, options): self.dilation = attr['dilations'] self.groups = attr['group'] self.output_padding = [attr.get('output_padding', [0, 0])[0], attr.get('output_padding', [0, 0])[1]] + assert len(attr['kernel_shape']) == 2 # 2d transposed convolution. if len(inputs) == 3: self.has_bias = True else: self.has_bias = False self.mode = options.get("conv_mode", "matrix") assert self.output_padding == [0, 0] - assert self.padding == [0, 0] assert self.dilation == [1, 1] assert self.stride[0] == self.stride[1] assert self.groups == 1 + self.F_convtranspose = F.conv_transpose2d + def forward(self, *x): # x[0]: input, x[1]: weight, x[2]: bias if self.has_bias bias = x[2] if self.has_bias else None @@ -443,7 +498,7 @@ def forward(self, *x): return output - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): if self.is_input_perturbed(1): raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.") @@ -456,14 +511,7 @@ def _bound_oneside(last_A): return None, 0 if type(last_A) is OneHotC: # Conv layer does not support the OneHotC fast path. We have to create a dense matrix instead. - shape = last_A.shape # [spec, batch, C, H, W] - dim = int(prod(shape[2:])) - dense_last_A = torch.zeros(size=(shape[0], shape[1], dim), device=last_A.device, dtype=weight.dtype) - # last_A.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A. - # last_A.coeffs is the value of the non-zero element. - dense_last_A = torch.scatter(dense_last_A, dim=2, index=last_A.index.unsqueeze(-1), src=last_A.coeffs.unsqueeze(-1)) - # We created a large A matrix and it will be handled below. - last_A = dense_last_A.view(shape[0], shape[1], *shape[2:]) + last_A = onehotc_to_dense(last_A, dtype=weight.dtype) if type(last_A) == Tensor: shape = last_A.size() @@ -495,7 +543,12 @@ def _bound_oneside(last_A): flattened_patches = patches.reshape(-1, patches.size(-3), patches.size(-2), patches.size(-1)) # Merge patches with this layer's weights. Weight must be flipped here; and if stride != 1, we must insert zeros in the input image. # For conv_transpose2d, the weight matrix is in the (in, out, k, k) shape. - pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=last_A.inserted_zeros + 1) + # pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=self.stride) + # pieces = F.conv_transpose2d(flattened_patches, weight.transpose(0,1).flip(-1,-2), stride=last_A.inserted_zeros + 1) + # Use padding in conv_transposed2d directly. + pieces = F.conv_transpose2d( + # Transpose because the weight has in_channel before out_channel. + flattened_patches, insert_zeros(weight.transpose(0,1).flip(-1,-2), last_A.inserted_zeros)) # New patch size: (out_c, batch, out_h, out_w, c, h, w) or (unstable_size, batch, c, h, w). pieces = pieces.view(*patches.shape[:-3], pieces.size(-3), pieces.size(-2), pieces.size(-1)) @@ -512,23 +565,23 @@ def _bound_oneside(last_A): sum_bias = x[2].lower.view(-1, 1, 1, 1).expand(-1, *last_A.shape[1:4]) else: raise NotImplementedError() - padding = last_A.padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom) + patches_padding = last_A.padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom) output_padding = last_A.output_padding if last_A is not None else (0, 0, 0, 0) # (left, right, top, bottom) inserted_zeros = last_A.inserted_zeros - assert self.padding == [0, 0] assert self.stride[0] == self.stride[1] # Unify the shape to 4-tuple. output_padding = unify_shape(output_padding) - padding = unify_shape(padding) + patches_padding = unify_shape(patches_padding) this_stride = unify_shape(self.stride) this_padding = unify_shape(self.padding) - # Compute new padding. - padding = tuple(p + (weight.size(3 - j//2) - 1) for j, p in enumerate(padding)) + # Compute new padding. Due to the shape flip during merging, we need to check the string/size on the dimension 3 - j. + # TODO: testing for asymmetric shapes. + padding = tuple(p * (inserted_zeros + 1) + (weight.size(3 - j//2) - 1) for j, p in enumerate(patches_padding)) # Compute new output padding - output_padding = tuple(p * this_stride[j] + this_padding[j] for j, p in enumerate(output_padding)) + output_padding = tuple(p * (inserted_zeros + 1) + this_padding[j] for j, p in enumerate(output_padding)) # When we run insert_zeros, it's missing the right most column and the bottom row. # padding = (padding[0], padding[1] + inserted_zeros, padding[2], padding[3] + inserted_zeros) @@ -548,8 +601,14 @@ def _bound_oneside(last_A): sum_bias = sum_bias.reshape(sum_bias.size(0), -1).transpose(0,1) A_matrix = A_matrix.transpose(0,1) # Spec dimension at the front. return A_matrix, sum_bias - return Patches(pieces, last_A.stride, padding, pieces.shape, unstable_idx=last_A.unstable_idx, - output_shape=last_A.output_shape, inserted_zeros=inserted_zeros, output_padding=output_padding), sum_bias + new_patches = last_A.create_similar( + pieces, padding=padding, inserted_zeros=inserted_zeros, output_padding=output_padding, + input_shape=self.input_shape) + # if last_A is last_lA: + # print(f'ConvT input : start_node {kwargs["start_node"].name} layer {self.name} {last_lA}') + # print(f'ConvT layer : padding {self.padding} stride {self.stride} kernel {list(weight.shape[-2:])} input {list(self.input_shape)} output {list(self.output_shape)}') + # print(f'ConvT output: start_node {kwargs["start_node"].name} layer {self.name} {new_patches}') + return new_patches, sum_bias else: raise NotImplementedError() @@ -568,7 +627,7 @@ def interval_propagate(self, *v, C=None): weight = v[1][0] bias = v[2][0] if self.has_bias else None - if norm == np.inf: + if norm == torch.inf: mid = (h_U + h_L) / 2.0 diff = (h_U - h_L) / 2.0 weight_abs = weight.abs() @@ -590,8 +649,44 @@ def interval_propagate(self, *v, C=None): lower = center - deviation return lower, upper + def bound_forward(self, dim_in, *x): + if self.is_input_perturbed(1) or self.is_input_perturbed(2): + raise NotImplementedError("Weight perturbation for convolution layers has not been implmented.") + + weight = x[1].lb + bias = x[2].lb if self.has_bias else None + x = x[0] + + mid_w = (x.lw + x.uw) / 2 + mid_b = (x.lb + x.ub) / 2 + diff_w = (x.uw - x.lw) / 2 + diff_b = (x.ub - x.lb) / 2 + weight_abs = weight.abs() + shape = mid_w.shape + shape_wconv = [shape[0] * shape[1]] + list(shape[2:]) + deviation_w = self.F_convtranspose( + diff_w.reshape(shape_wconv), weight_abs, None, output_padding=self.output_padding, + stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + deviation_b = self.F_convtranspose( + diff_b, weight_abs, None, output_padding=self.output_padding, + stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + center_w = self.F_convtranspose( + mid_w.reshape(shape_wconv), weight, output_padding=self.output_padding, + stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + center_b = self.F_convtranspose( + mid_b, weight, bias, output_padding=self.output_padding, + stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + deviation_w = deviation_w.reshape(shape[0], -1, *deviation_w.shape[1:]) + center_w = center_w.reshape(shape[0], -1, *center_w.shape[1:]) + + return LinearBound( + lw = center_w - deviation_w, + lb = center_b - deviation_b, + uw = center_w + deviation_w, + ub = center_b + deviation_b) + class BoundPad(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) if hasattr(attr, 'pads'): self.padding = attr['pads'][2:4] + attr['pads'][6:8] @@ -614,7 +709,7 @@ def interval_propagate(self, *v): l, u = zip(*v) return Interval.make_interval(self.forward(*l), self.forward(*u), v[0]) - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): # TODO: padding for 3-D or more dimensional inputs. left, right, top, bottom = self.padding def _bound_oneside(last_A): diff --git a/auto_LiRPA/operators/cut_ops.py b/auto_LiRPA/operators/cut_ops.py index e4337ae..c86f7bc 100644 --- a/auto_LiRPA/operators/cut_ops.py +++ b/auto_LiRPA/operators/cut_ops.py @@ -192,7 +192,7 @@ def jit_arelu_lA(last_lA, lower, upper, beta_mm_coeffs, unstable_or_cut_index, u tao, pi = tao.clamp(min=0.), pi.clamp(min=0.) tao, pi = torch.min(tao, nu_hat_pos), torch.min(pi, nu_hat_pos) new_upper_d = pi / (pi + tao + 1e-10) - # need to customize the upper bound slope and lbias for (1) unstable relus and + # need to customize the upper bound slope and lbias for (1) unstable relus and # (2) relus that are used with upper boundary relaxation # original upper bound slope is u/(u-l) also equal to pi/(pi+tao) if no beta_mm_coeffs[0] # now the upper bound slope should be pi/(p+tao) updated with beta_mm_coeffs[0] @@ -294,7 +294,7 @@ def arelu_cut(self, start_node, layer_name, last_lA, last_uA, lower_d, upper_d, beta_mm_coeffs = self.general_beta_coeffs_mm(general_beta, arelu_coeffs, A, current_layer_shape) # unstable_this_layer = torch.logical_and(x.lower < 0, x.upper > 0).unsqueeze(0) # I is the unstable index in this relu layer: (batch, *layer shape) - # if there is node in cut constraint that is stable, also need to count its effect + # if there is node in cut constraint that is stable, also need to count its effect # self.arelu_coeffs: (num_constrs, flattened current layer) unstable_or_cut_index = I.logical_or(arelu_coeffs.sum(0).view(I[0:1].shape) != 0) @@ -389,7 +389,7 @@ def arelu_cut(self, start_node, layer_name, last_lA, last_uA, lower_d, upper_d, # assert ((tao + pi - nu_hat_pos).abs()*unstable_or_cut_index).max() <= 1e-5, "pi+tao should always be the same as nu_hat_pos" - # # need to customize the upper bound slope and lbias for (1) unstable relus and + # # need to customize the upper bound slope and lbias for (1) unstable relus and # # (2) relus that are used with upper boundary relaxation # # original upper bound slope is u/(u-l) also equal to pi/(pi+tao) if no beta_mm_coeffs[0] # # now the upper bound slope should be pi/(p+tao) updated with beta_mm_coeffs[0] @@ -585,4 +585,4 @@ def _maybe_unfold(d_tensor, last_A): # For regular patches, the shape after unfold is (spec, batch, out_h, out_w, in_c, H, W). if d_unfolded_r.ndim != last_A.patches.ndim: d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4) - return d_unfolded_r + return d_unfolded_r \ No newline at end of file diff --git a/auto_LiRPA/operators/dropout.py b/auto_LiRPA/operators/dropout.py index d4601a6..3de2013 100644 --- a/auto_LiRPA/operators/dropout.py +++ b/auto_LiRPA/operators/dropout.py @@ -1,7 +1,7 @@ from .base import * class BoundDropout(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) if 'ratio' in attr: self.ratio = attr['ratio'] @@ -21,13 +21,14 @@ def forward(self, *inputs): if self.dynamic: # Inputs: data, ratio (optional), training_mode (optional) # We assume ratio must exist in the inputs. - # We ignore training_mode, but will use self.training which can be + # We ignore training_mode, but will use self.training which can be # changed after BoundedModule is built. - assert inputs[1].dtype == torch.float32 + assert (inputs[1].dtype == torch.float32 or + inputs[1].dtype == torch.float64) self.ratio = inputs[1] if self.ratio >= 1: raise ValueError('Ratio in dropout should be less than 1') - self.mask = torch.rand(x.shape) > self.ratio + self.mask = torch.rand(x.shape, device=self.ratio.device) > self.ratio return x * self.mask / (1 - self.ratio) def _check_forward(self): @@ -36,7 +37,7 @@ def _check_forward(self): raise RuntimeError('For a model with dropout in the training mode, '\ 'a clean forward pass must be called before bound computation') - def bound_backward(self, last_lA, last_uA, *args): + def bound_backward(self, last_lA, last_uA, *args, **kwargs): empty_A = [(None, None)] * (len(args) -1) if not self.training: return [(last_lA, last_uA), *empty_A], 0, 0 @@ -63,7 +64,7 @@ def interval_propagate(self, *v): if not self.training: return v[0] self._check_forward() - h_L, h_U = v[0] + h_L, h_U = v[0] lower = h_L * self.mask / (1 - self.ratio) upper = h_U * self.mask / (1 - self.ratio) - return lower, upper \ No newline at end of file + return lower, upper diff --git a/auto_LiRPA/operators/dtype.py b/auto_LiRPA/operators/dtype.py index d5dfa38..70ba54d 100644 --- a/auto_LiRPA/operators/dtype.py +++ b/auto_LiRPA/operators/dtype.py @@ -1,25 +1,27 @@ -from .base import * +from .base import * from ..utils import Patches class BoundCast(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.to = attr['to'] + # See values of enum DataType in TensorProto. + # Unsupported: str, uint16, uint32, uint64. self.data_types = [ None, torch.float, torch.uint8, torch.int8, None, torch.int16, torch.int32, torch.int64, - None, torch.bool, torch.float16, torch.float32, - None, None + None, torch.bool, torch.float16, torch.float64, + None, None, torch.complex64, torch.complex128 ] self.type = self.data_types[self.to] - assert self.type is not None + assert self.type is not None, "Unsupported type conversion." self.use_default_ibp = True def forward(self, x): self.type_in = x.dtype return x.to(self.type) - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): if type(last_lA) == Tensor or type(last_lA) == Tensor: lA = last_lA.to(self.type_in) if last_lA is not None else None uA = last_uA.to(self.type_in) if last_uA is not None else None diff --git a/auto_LiRPA/operators/gradient_bounds.py b/auto_LiRPA/operators/gradient_bounds.py index d207af2..deaf608 100644 --- a/auto_LiRPA/operators/gradient_bounds.py +++ b/auto_LiRPA/operators/gradient_bounds.py @@ -1,7 +1,6 @@ """ Bound classes for gradient operators """ import torch import torch.nn.functional as F -import numpy as np from auto_LiRPA.patches import Patches, inplace_unfold from .base import Bound, Interval from .activation_base import BoundActivation @@ -44,7 +43,7 @@ def _maybe_unfold(d_tensor, last_A): class BoundReluGrad(BoundActivation): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.requires_input_bounds = [3] self.recurjac = options.get('recurjac', False) @@ -54,6 +53,8 @@ def relu_grad(preact): return (preact > 0).float() def forward(self, g, g_relu, g_relu_rev, preact): + if g.ndim == preact.ndim + 1: + preact = preact.unsqueeze(1) return g * relu_grad(preact) def interval_propagate(self, *v): @@ -61,11 +62,15 @@ def interval_propagate(self, *v): preact_lower, preact_upper = v[3] relu_grad_lower = relu_grad(preact_lower) relu_grad_upper = relu_grad(preact_upper) + if g_lower.ndim == relu_grad_lower.ndim + 1: + relu_grad_lower = relu_grad_lower.unsqueeze(1) + relu_grad_upper = relu_grad_upper.unsqueeze(1) lower = torch.min(g_lower * relu_grad_lower, g_lower * relu_grad_upper) upper = torch.max(g_upper * relu_grad_lower, g_upper * relu_grad_upper) return lower, upper - def bound_backward(self, last_lA, last_uA, g, g_relu, g_relu_rev, preact): + def bound_backward(self, last_lA, last_uA, g, g_relu, g_relu_rev, preact, + **kwargs): mask_active = (preact.lower > 0).float() mask_inactive = (preact.upper < 0).float() mask_unstable = 1 - mask_active - mask_inactive @@ -159,7 +164,7 @@ def _bound_oneside(last_A, pos_interval=None, neg_interval=None): class BoundConv2dGrad(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.stride = attr['stride'] self.padding = attr['padding'] @@ -180,7 +185,7 @@ def forward(self, *x): stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups, output_padding=self.output_padding) - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): assert not self.is_input_perturbed(1) lA_y = uA_y = lA_bias = uA_bias = None @@ -297,7 +302,7 @@ def interval_propagate(self, *v, C=None): weight = v[1][0] bias = v[2][0] if self.has_bias else None - if norm == np.inf: + if norm == torch.inf: mid = (h_U + h_L) / 2.0 diff = (h_U - h_L) / 2.0 weight_abs = weight.abs() diff --git a/auto_LiRPA/operators/gradient_modules.py b/auto_LiRPA/operators/gradient_modules.py index fd05020..fd5d01f 100644 --- a/auto_LiRPA/operators/gradient_modules.py +++ b/auto_LiRPA/operators/gradient_modules.py @@ -55,7 +55,8 @@ def __init__(self, weight): self.weight = weight def forward(self, grad_last): - return F.linear(grad_last, self.weight.t()) + weight = self.weight.to(grad_last).t() + return F.linear(grad_last, weight) class ReLUGradOp(Function): @@ -65,7 +66,7 @@ class ReLUGradOp(Function): """ @staticmethod def symbolic(_, g, g_relu, g_relu_rev, preact): - return _.op('grad::Relu', g, g_relu, g_relu_rev, preact) + return _.op('grad::Relu', g, g_relu, g_relu_rev, preact).setType(g.type()) @staticmethod def forward(ctx, g, g_relu, g_relu_rev, preact): @@ -81,17 +82,10 @@ def forward(self, g, preact): class ReshapeGrad(Module): def forward(self, grad_last, inp): - return grad_last.reshape( - grad_last.size(0), *inp.shape[1:]) - - -class FlattenGrad(Module): - def __init__(self, in_shape): - super().__init__() - self.in_shape = in_shape - - def forward(self, grad_last): - return torch.reshape(grad_last, [-1] + list(self.in_shape)) + if grad_last.numel() == inp.numel(): + return grad_last.reshape(grad_last.shape[0], *inp.shape[1:]) + else: + return grad_last.reshape(*grad_last.shape[:2], *inp.shape[1:]) class Conv2dGradOp(Function): @@ -102,7 +96,7 @@ def symbolic( return g.op( 'grad::Conv2d', x, w, stride_i=stride, padding_i=padding, dilation_i=dilation, groups_i=groups, output_padding0_i=output_padding0, - output_padding1_i=output_padding1) + output_padding1_i=output_padding1).setType(x.type()) @staticmethod def forward( diff --git a/auto_LiRPA/operators/jacobian.py b/auto_LiRPA/operators/jacobian.py new file mode 100644 index 0000000..6802abd --- /dev/null +++ b/auto_LiRPA/operators/jacobian.py @@ -0,0 +1,34 @@ +import torch +from .base import Bound + + +class JacobianOP(torch.autograd.Function): + @staticmethod + def symbolic(g, output, input): + return g.op('grad::jacobian', output, input).setType(output.type()) + + @staticmethod + def forward(ctx, output, input): + output_ = output.flatten(1) + return torch.zeros( + output.shape[0], output_.shape[-1], *input.shape[1:], + device=output.device) + + +class BoundJacobianOP(Bound): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + + def forward(self, output, input): + return JacobianOP.apply(output, input) + + +class BoundJacobianInit(Bound): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.never_perturbed = True + + def forward(self, x): + x = x.flatten(1) + eye = torch.eye(x.shape[-1], device=x.device) + return eye.unsqueeze(0).repeat(x.shape[0], 1, 1) diff --git a/auto_LiRPA/operators/leaf.py b/auto_LiRPA/operators/leaf.py index e43850e..458264d 100644 --- a/auto_LiRPA/operators/leaf.py +++ b/auto_LiRPA/operators/leaf.py @@ -131,7 +131,7 @@ def forward(self): def bound_forward(self, dim_in): assert 0 - def bound_backward(self, last_lA, last_uA): + def bound_backward(self, last_lA, last_uA, **kwargs): raise ValueError('{} is a BoundInput node and should not be visited here'.format( self.name)) @@ -140,7 +140,7 @@ def interval_propagate(self, *v): self.name)) class BoundParams(BoundInput): - def __init__(self, ori_name, value, perturbation=None): + def __init__(self, ori_name, value, perturbation=None, options=None): super().__init__(ori_name, None, perturbation) self.register_parameter('param', value) self.from_input = False @@ -164,9 +164,10 @@ def forward(self): return self.param.requires_grad_(self.training) class BoundBuffers(BoundInput): - def __init__(self, ori_name, value, perturbation=None): + def __init__(self, ori_name, value, perturbation=None, options=None): super().__init__(ori_name, None, perturbation) self.register_buffer('buffer', value.clone().detach()) + self.from_input = not options.get('buffers', {}).get('no_batchdim', False) def forward(self): return self.buffer diff --git a/auto_LiRPA/operators/linear.py b/auto_LiRPA/operators/linear.py index 3f1aee8..36c39b8 100644 --- a/auto_LiRPA/operators/linear.py +++ b/auto_LiRPA/operators/linear.py @@ -1,14 +1,19 @@ """ Linear (possibly with weight perturbation) or Dot product layers """ from torch import Tensor +from typing import Tuple, List +from .activation_base import BoundOptimizableActivation from .base import * -from .bivariate import BoundMul +from .bivariate import BoundMul, MulHelper from .gradient_modules import LinearGrad +from .leaf import BoundParams from ..patches import Patches, inplace_unfold from .solver_utils import grb +from .clampmult import multiply_by_A_signs +EPS = 1e-2 -class BoundLinear(Bound): - def __init__(self, attr, inputs, output_index, options): +class BoundLinear(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): # Gemm: # A = A if transA == 0 else A.T # B = B if transB == 0 else B.T @@ -21,30 +26,69 @@ def __init__(self, attr, inputs, output_index, options): # Defaults in ONNX self.transA = 0 self.transB = 0 - self.alpha = 1.0 - self.beta = 1.0 + self.alpha_linear = 1.0 + self.beta_linear = 1.0 if attr is not None: self.transA = attr['transA'] if 'transA' in attr else self.transA self.transB = attr['transB'] if 'transB' in attr else self.transB - self.alpha = attr['alpha'] if 'alpha' in attr else self.alpha - self.beta = attr['beta'] if 'beta' in attr else self.beta + self.alpha_linear = attr['alpha'] if 'alpha' in attr else self.alpha_linear + self.beta_linear = attr['beta'] if 'beta' in attr else self.beta_linear + options = options or {} self.opt_matmul = options.get('matmul') + self.splittable = False + + self.mul_helper = MulHelper() + self.use_seperate_weights_for_lower_and_upper_bounds = False + self.share_alphas = options.get('matmul', {}).get('share_alphas', False) def _preprocess(self, a, b, c=None): """Handle tranpose and linear coefficients.""" if self.transA and isinstance(a, Tensor): a = a.transpose(-2,-1) - if self.alpha != 1.0: - a = self.alpha * a + if self.alpha_linear != 1.0: + a = self.alpha_linear * a if not self.transB and isinstance(b, Tensor): - # our code assumes B is transposed (common case), so we transpose B only when it is not transposed in gemm. + # our code assumes B is transposed (common case), so we transpose B + # only when it is not transposed in gemm. b = b.transpose(-2, -1) if c is not None: - if self.beta != 1.0: - c = self.beta * c + if self.beta_linear != 1.0: + c = self.beta_linear * c return a, b, c + def init_opt_parameters(self, start_nodes): + shared_alpha_dims = [] + if self.share_alphas: + # TODO Temporarily an adhoc check for alpha sharing. + count_matmul = len([item for item in self._all_optimizable_activations + if isinstance(item, BoundLinear)]) + if count_matmul >= 6: + shared_alpha_dims = [1, 2, 3] + elif count_matmul >= 4: + shared_alpha_dims = [1, 2] + + input_lb = [getattr(xi, 'lower', None) for xi in self.inputs] + input_ub = [getattr(xi, 'upper', None) for xi in self.inputs] + input_lb = self._preprocess(*input_lb) + input_ub = self._preprocess(*input_ub) + x_l, x_u, y_l, y_u = self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1]) + assert x_l.ndim == y_l.ndim + shape = [1 if i in shared_alpha_dims + else max(x_l.shape[i], y_l.shape[i]) for i in range(x_l.ndim)] + for start_node in start_nodes: + ns, size_s = start_node[:2] + # start_node[3] == False means that this start node is not the final node + # if not start_node[3]: + # # NOTE Experimental code. Please check how it will impact the results. + # size_s = 1 + if isinstance(size_s, torch.Size): + # TODO do not give torch.Size + size_s = prod(size_s) + elif isinstance(size_s, (list, tuple)): + size_s = size_s[0] + self.alpha[ns] = torch.ones(4, size_s, *shape, device=x_l.device) + def forward(self, x, w, b=None): x, w, b = self._preprocess(x, w, b) self.input_shape = self.x_shape = x.shape @@ -74,15 +118,23 @@ def onehot_mult(self, weight, bias, C, batch_size): if C.index.ndim == 1: # Every element in the batch shares the same rows. if weight is not None: - new_weight = self.non_deter_index_select(weight, dim=0, index=index).unsqueeze(1).expand([-1, batch_size] + [-1] * (weight.ndim - 1)) + new_weight = self.non_deter_index_select( + weight, dim=0, index=index + ).unsqueeze(1).expand( + [-1, batch_size] + [-1] * (weight.ndim - 1)) if bias is not None: - new_bias = self.non_deter_index_select(bias, dim=0, index=index).unsqueeze(1).expand(-1, batch_size) + new_bias = self.non_deter_index_select( + bias, dim=0, index=index + ).unsqueeze(1).expand(-1, batch_size) elif C.index.ndim == 2: - # Every element in the batch has different rows, but the number of rows are the same. This essentially needs a batched index_select function. + # Every element in the batch has different rows, but the number of + # rows are the same. This essentially needs a batched index_select function. if weight is not None: - new_weight = batched_index_select(weight.unsqueeze(0), dim=1, index=index) + new_weight = batched_index_select( + weight.unsqueeze(0), dim=1, index=index) if bias is not None: - new_bias = batched_index_select(bias.unsqueeze(0), dim=1, index=index) + new_bias = batched_index_select( + bias.unsqueeze(0), dim=1, index=index) if C.coeffs is not None: if weight is not None: new_weight = new_weight * coeffs.unsqueeze(-1) @@ -94,8 +146,11 @@ def onehot_mult(self, weight, bias, C, batch_size): new_bias = new_bias.transpose(0, 1) return new_weight, new_bias - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, start_node=None, + reduce_bias=True, **kwargs): assert len(x) == 2 or len(x) == 3 + if start_node is not None: + self._start = start_node.name has_bias = len(x) == 3 # x[0]: input node, x[1]: weight, x[2]: bias input_lb = [getattr(xi, 'lower', None) for xi in x] @@ -106,103 +161,149 @@ def bound_backward(self, last_lA, last_uA, *x): lA_y = uA_y = lA_bias = uA_bias = None lbias = ubias = 0 batch_size = last_lA.shape[1] if last_lA is not None else last_uA.shape[1] + weight = input_lb[1] + bias = input_lb[2] if has_bias else None + + def _bound_oneside(last_A): + if last_A is None: + return None, 0 + if isinstance(last_A, torch.Tensor): + # Matrix mode. + # Just multiply this layer's weight into bound matrices, and produce biases. + next_A = last_A.to(weight).matmul(weight) + sum_bias = (last_A.to(bias).matmul(bias) + if has_bias else 0.0) + elif isinstance(last_A, Patches): + # Patches mode. After propagating through this layer, it will become a matrix. + # Reshape the weight matrix as a conv image. + # Weight was in (linear_output_shape, linear_input_shape) + # Reshape it to (linear_input_shape, c, h, w) + reshaped_weight = weight.transpose(0, 1).view( + -1, *last_A.input_shape[1:]) + # After unfolding the shape is + # (linear_input_shape, output_h, output_w, in_c, patch_h, patch_w) + unfolded_weight = inplace_unfold( + reshaped_weight, + kernel_size=last_A.patches.shape[-2:], + stride=last_A.stride, padding=last_A.padding, + inserted_zeros=last_A.inserted_zeros, + output_padding=last_A.output_padding) + if has_bias: + # Do the same for the bias. + reshaped_bias = bias.view(*last_A.input_shape[1:]).unsqueeze(0) + # After unfolding the bias shape is (1, output_h, output_w, in_c, patch_h, patch_w) + unfolded_bias = inplace_unfold( + reshaped_bias, kernel_size=last_A.patches.shape[-2:], + stride=last_A.stride, padding=last_A.padding, + inserted_zeros=last_A.inserted_zeros, + output_padding=last_A.output_padding) + if last_A.unstable_idx is not None: + # In this case, the last_A shape is (num_unstable, batch, out_c, patch_h, patch_w) + # Reshape our weight to (output_h, output_w, 1, in_c, patch_h, patch_w, linear_input_shape), 1 is the inserted batch dim. + unfolded_weight_r = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(2) + # for sparse patches the shape is (unstable_size, batch, in_c, patch_h, patch_w). Batch size is 1 so no need to select here. + # We select in the (output_h, out_w) dimension. + selected_weight = unfolded_weight_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + next_A = torch.einsum('sbchw,sbchwi->sbi', last_A.patches, selected_weight) + if has_bias: + # Reshape our bias to (output_h, output_w, 1, in_c, patch_h, patch_w). We already have the batch dim. + unfolded_bias_r = unfolded_bias.permute(1, 2, 0, 3, 4, 5) + selected_bias = unfolded_bias_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + sum_bias = torch.einsum('sbchw,sbchw->sb', last_A.patches, selected_bias) + else: + # Reshape our weight to (1, 1, output_h, output_w, in_c, patch_h, patch_w, linear_input_shape), 1 is the spec and batch. + selected_weight = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(0).unsqueeze(0) + next_A_r = torch.einsum('sbpqchw,sbpqchwi->spqbi', last_A.patches, selected_weight) + # We return a matrix with flattened spec dimension (corresponding to out_c * out_h * out_w). + next_A = next_A_r.reshape(-1, next_A_r.size(-2), next_A_r.size(-1)) + if has_bias: + # Reshape our bias to (1, 1, output_h, output_w, in_c, patch_h, patch_w) + selected_bias = unfolded_bias.unsqueeze(0) + sum_bias_r = torch.einsum('sbpqchw,sbpqchw->spqb', last_A.patches, selected_bias) + sum_bias = sum_bias_r.reshape(-1, sum_bias_r.size(-1)) + return next_A, sum_bias if has_bias else 0.0 # Case #1: No weight/bias perturbation, only perturbation on input. if not self.is_input_perturbed(1) and (not has_bias or not self.is_input_perturbed(2)): - weight = input_lb[1] - bias = input_lb[2] if has_bias else None # If last_lA and last_uA are indentity matrices. - if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): # FIXME (12/28): we should check last_lA and last_uA separately. Same applies to the weight perturbed, bias perturbed settings. - # Use this layer's W as the next bound matrices. Duplicate the batch dimension. Other dimensions are kept 1. - # Not perturbed, so we can use either lower or upper. - assert last_lA.shape == last_uA.shape - shape_others = prod(last_lA.shape[2:-1]) - A_identity = torch.eye(shape_others).to(weight).view(shape_others, 1, 1, shape_others, 1) - assert last_lA.shape[0] == weight.size(0) * shape_others - w = weight.view(1, weight.size(0), *[1] * (len(last_lA.shape) - 2), weight.size(1)) - w = w * A_identity - - # expand the batch_size dim - lA_x = uA_x = w.view(last_lA.shape[0], 1, *last_lA.shape[2:-1], weight.size(1)).expand(last_lA.shape[0], *last_lA.shape[1:-1], weight.size(1)) - if has_bias: - lbias = ubias = bias.unsqueeze(1).repeat(1, batch_size) - elif isinstance(last_lA, OneHotC) or isinstance(last_uA, OneHotC): - # We need to select several rows from the weight matrix (its shape is output_size * input_size). - lA_x, lbias = self.onehot_mult(weight, bias, last_lA, batch_size) - if last_lA is last_uA: - uA_x = lA_x - ubias = lbias + # FIXME (12/28): we should check last_lA and last_uA separately. + # Same applies to the weight perturbed, bias perturbed settings. + + def multiply_with_weight(weight, set_l: bool, set_u: bool): + lA_x = uA_x = None + lbias = ubias = 0. + if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): + # Use this layer's W as the next bound matrices. Duplicate the batch dimension. Other dimensions are kept 1. + # Not perturbed, so we can use either lower or upper. + assert last_lA.shape == last_uA.shape + shape_others = prod(last_lA.shape[2:-1]) + A_identity = torch.eye( + shape_others, device=weight.device, dtype=weight.dtype + ).view(shape_others, 1, 1, shape_others, 1) + assert last_lA.shape[0] == weight.size(0) * shape_others + w = weight.view( + 1, weight.size(0), *[1] * (len(last_lA.shape) - 2), + weight.size(1)) + w = w * A_identity + + # expand the batch_size dim + tmp_A_x = w.reshape( + last_lA.shape[0], 1, *last_lA.shape[2:-1], weight.size(1) + ).expand(last_lA.shape[0], *last_lA.shape[1:-1], weight.size(1)) + if set_l: + lA_x = tmp_A_x + if set_u: + uA_x = tmp_A_x + + if has_bias: + tmp_bias = bias.unsqueeze(1).repeat(1, batch_size) + if set_l: + lbias = tmp_bias + if set_u: + ubias = tmp_bias + elif isinstance(last_lA, OneHotC) or isinstance(last_uA, OneHotC): + # We need to select several rows from the weight matrix + # (its shape is output_size * input_size). + if set_l: + lA_x, lbias = self.onehot_mult(weight, bias, last_lA, batch_size) + if last_lA is last_uA and set_l and set_u: + uA_x = lA_x + ubias = lbias + elif set_u: + uA_x, ubias = self.onehot_mult(weight, bias, last_uA, batch_size) else: - uA_x, ubias = self.onehot_mult(weight, bias, last_uA, batch_size) + if set_l: + lA_x, lbias = _bound_oneside(last_lA) + if set_u: + uA_x, ubias = _bound_oneside(last_uA) + return lA_x, uA_x, lbias, ubias + + if self.use_seperate_weights_for_lower_and_upper_bounds: + lA_x, _, lbias, _ = multiply_with_weight(input_lb[1], set_l=True, set_u=False) + _, uA_x, _, ubias = multiply_with_weight(input_ub[1], set_l=False, set_u=True) else: - def _bound_oneside(last_A): - if last_A is None: - return None, 0 - if isinstance(last_A, torch.Tensor): - # Matrix mode. - # Just multiply this layer's weight into bound matrices, and produce biases. - next_A = last_A.to(weight).matmul(weight) - sum_bias = (last_A.to(bias).matmul(bias) - if has_bias else 0.0) - elif isinstance(last_A, Patches): - # Patches mode. After propagating through this layer, it will become a matrix. - # Reshape the weight matrix as a conv image. - # Weight was in (linear_output_shape, linear_input_shape) - # Reshape it to (linear_input_shape, c, h, w) - reshaped_weight = weight.transpose(0,1).view(-1, *last_A.input_shape[1:]) - # After unfolding the shape is (linear_input_shape, output_h, output_w, in_c, patch_h, patch_w) - unfolded_weight = inplace_unfold( - reshaped_weight, - kernel_size=last_A.patches.shape[-2:], - stride=last_A.stride, padding=last_A.padding, - inserted_zeros=last_A.inserted_zeros, - output_padding=last_A.output_padding) - if has_bias: - # Do the same for the bias. - reshaped_bias = bias.view(*last_A.input_shape[1:]).unsqueeze(0) - # After unfolding the bias shape is (1, output_h, output_w, in_c, patch_h, patch_w) - unfolded_bias = inplace_unfold(reshaped_bias, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) - if last_A.unstable_idx is not None: - # Reshape our weight to (output_h, output_w, 1, in_c, patch_h, patch_w, linear_input_shape), 1 is the inserted batch dim. - unfolded_weight_r = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(2) - # for sparse patches the shape is (unstable_size, batch, in_c, patch_h, patch_w). Batch size is 1 so no need to select here. - # We select in the (output_h, out_w) dimension. - selected_weight = unfolded_weight_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] - next_A = torch.einsum('sbchw,sbchwi->sbi', last_A.patches, selected_weight) - if has_bias: - # Reshape our bias to (output_h, output_w, 1, in_c, patch_h, patch_w). We already have the batch dim. - unfolded_bias_r = unfolded_bias.permute(1, 2, 0, 3, 4, 5) - selected_bias = unfolded_bias_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] - sum_bias = torch.einsum('sbchw,sbchw->sb', last_A.patches, selected_bias) - else: - # Reshape our weight to (1, 1, output_h, output_w, in_c, patch_h, patch_w, linear_input_shape), 1 is the spec and batch. - selected_weight = unfolded_weight.permute(1, 2, 3, 4, 5, 0).unsqueeze(0).unsqueeze(0) - next_A_r = torch.einsum('sbpqchw,sbpqchwi->spqbi', last_A.patches, selected_weight) - # We return a matrix with flattened spec dimension (corresponding to out_c * out_h * out_w). - next_A = next_A_r.reshape(-1, next_A_r.size(-2), next_A_r.size(-1)) - if has_bias: - # Reshape our bias to (1, 1, output_h, output_w, in_c, patch_h, patch_w) - selected_bias = unfolded_bias.unsqueeze(0) - sum_bias_r = torch.einsum('sbpqchw,sbpqchw->spqb', last_A.patches, selected_bias) - sum_bias = sum_bias_r.reshape(-1, sum_bias_r.size(-1)) - return next_A, sum_bias if has_bias else 0.0 - - lA_x, lbias = _bound_oneside(last_lA) - uA_x, ubias = _bound_oneside(last_uA) + lA_x, uA_x, lbias, ubias = multiply_with_weight(weight, set_l=True, set_u=True) # Case #2: weight is perturbed. bias may or may not be perturbed. elif self.is_input_perturbed(1): + assert not self.use_seperate_weights_for_lower_and_upper_bounds # Obtain relaxations for matrix multiplication. - [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias = self.bound_backward_with_weight(last_lA, last_uA, input_lb, input_ub, x[0], x[1]) + [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias = self.bound_backward_with_weight( + last_lA, last_uA, input_lb, input_ub, x[0], x[1], + reduce_bias=reduce_bias, **kwargs) if has_bias: + assert reduce_bias if x[2].perturbation is not None: - # Bias is also perturbed. Since bias is directly added to the output, in backward mode it is treated - # as an input with last_lA and last_uA as associated bounds matrices. - # It's okay if last_lA or last_uA is eyeC, as it will be handled in the perturbation object. + # Bias is also perturbed. Since bias is directly added to the + # output, in backward mode it is treated as an input with + # last_lA and last_uA as associated bounds matrices. + # It's okay if last_lA or last_uA is eyeC, as it will be + # handled in the perturbation object. lA_bias = last_lA uA_bias = last_uA else: - # Bias not perturbed, so directly adding the bias of this layer to the final bound bias term. + # Bias not perturbed, so directly adding the bias of this + # layer to the final bound bias term. if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): # Bias will be directly added to output. lbias += input_lb[2].unsqueeze(1).repeat(1, batch_size) @@ -213,18 +314,23 @@ def _bound_oneside(last_A): if last_uA is not None: ubias += last_uA.matmul(input_lb[2]) # If not has_bias, no need to compute lA_bias and uA_bias - # Case 3: Only bias is perturbed, weight is not perturbed. elif not self.is_input_perturbed(1) and has_bias and self.is_input_perturbed(2): + assert not self.use_seperate_weights_for_lower_and_upper_bounds + assert reduce_bias if isinstance(last_lA, eyeC) and isinstance(last_uA, eyeC): - # Use this layer's W as the next bound matrices. Duplicate the batch dimension. Other dimensions are kept 1. - lA_x = uA_x = input_lb[1].unsqueeze(1).repeat([1, batch_size] + [1] * (input_lb[1].ndim - 1)) + # Use this layer's W as the next bound matrices. Duplicate the + # batch dimension. Other dimensions are kept 1. + lA_x = uA_x = input_lb[1].unsqueeze(1).repeat( + [1, batch_size] + [1] * (input_lb[1].ndim - 1)) else: lA_x = last_lA.matmul(input_lb[1]) uA_x = last_uA.matmul(input_lb[1]) # It's okay if last_lA or last_uA is eyeC, as it will be handled in the perturbation object. lA_bias = last_lA uA_bias = last_uA + else: + assert not self.use_seperate_weights_for_lower_and_upper_bounds return [(lA_x, uA_x), (lA_y, uA_y), (lA_bias, uA_bias)], lbias, ubias @@ -249,19 +355,43 @@ def _reshape(self, x_l, x_u, y_l, y_u): return x_l, x_u, y_l, y_u - def _relax(self, input_lb, input_ub): - return BoundMul.get_bound_mul(*self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1])) + @staticmethod + # @torch.jit.script + def propagate_A_xy(last_A: Tensor, alpha_pos: Tensor, alpha_neg: Tensor, + beta_pos: Tensor, beta_neg: Tensor, + dim_y: List[int]) -> Tuple[Tensor, Tensor]: + # last_uA has size (batch, spec, output) + last_A_pos = last_A.clamp(min=0).unsqueeze(-1) + last_A_neg = last_A.clamp(max=0).unsqueeze(-1) + # alpha_u has size (batch, spec, output, input) + # uA_x has size (batch, spec, input). + A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) + + alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1) + # beta_u has size (batch, spec, output, input) + # uA_y is for weight matrix, with parameter size (output, input) + # uA_y has size (batch, spec, output, input). This is an element-wise multiplication. + # TODO (for zhouxing/qirui): generalize multiply_by_A_signs() to calculate A_x, + # so last_A_pos and last_A_neg are not needed. This saves memory. + A_y, _ = multiply_by_A_signs(last_A.unsqueeze(-1), beta_pos, beta_neg, None, None) + if len(dim_y) != 0: + A_y = torch.sum(A_y, dim=dim_y) + return A_x, A_y + + def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, + x, y, reduce_bias=True, **kwargs): + # FIXME This is nonlinear. Move to `bivariate.py`. - # FIXME This is nonlinear. Move to `bivariate.py`. - def bound_backward_with_weight(self, last_lA, last_uA, input_lb, input_ub, x, y): # Note: x and y are not tranposed or scaled, and we should avoid using them directly. # Use input_lb and input_ub instead. - alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u = self._relax(input_lb, input_ub) - alpha_l, alpha_u = alpha_l.unsqueeze(0), alpha_u.unsqueeze(0) - beta_l, beta_u = beta_l.unsqueeze(0), beta_u.unsqueeze(0) - x_shape, y_shape = input_lb[0].size(), input_lb[1].size() - gamma_l = torch.sum(gamma_l, dim=-1).reshape(x_shape[0], -1, 1) - gamma_u = torch.sum(gamma_u, dim=-1).reshape(x_shape[0], -1, 1) + (alpha_l, beta_l, gamma_l, + alpha_u, beta_u, gamma_u) = self.mul_helper.get_relaxation( + *self._reshape(input_lb[0], input_ub[0], input_lb[1], input_ub[1]), + self.opt_stage, getattr(self, 'alpha', None), + getattr(self, '_start', None)) + x_shape = input_lb[0].size() + if reduce_bias: + gamma_l = torch.sum(gamma_l, dim=-1) + gamma_u = torch.sum(gamma_u, dim=-1) if len(x.output_shape) != 2 and len(x.output_shape) == len(y.output_shape): dim_y = [-3] @@ -279,32 +409,46 @@ def _bound_oneside(last_A, alpha_pos, beta_pos, gamma_pos, alpha_neg, beta_neg, last_A = (torch.eye(last_A.shape[0], device=last_A.device) .view(last_A.shape[0], 1, *last_A.shape[2:]).expand(last_A.shape)) - # last_uA has size (batch, spec, output) - last_A_pos = last_A.clamp(min=0).unsqueeze(-1) - last_A_neg = last_A.clamp(max=0).unsqueeze(-1) - # alpha_u has size (batch, spec, output, input) - # uA_x has size (batch, spec, input). - A_x = (alpha_pos.transpose(-1, -2).matmul(last_A_pos) + \ - alpha_neg.transpose(-1, -2).matmul(last_A_neg)).squeeze(-1) - # beta_u has size (batch, spec, output, input) - # uA_y is for weight matrix, with parameter size (output, input) - # uA_y has size (batch, spec, output, input). This is an element-wise multiplication. - A_y = last_A_pos * beta_pos + last_A_neg * beta_neg - if len(dim_y) != 0: - A_y = torch.sum(A_y, dim=dim_y) - # last_uA has size (batch, spec, output) - _last_A_pos = last_A_pos.reshape(last_A.shape[0], last_A.shape[1], -1) - _last_A_neg = last_A_neg.reshape(last_A.shape[0], last_A.shape[1], -1) - # gamma_u has size (batch, output, 1) - # ubias has size (batch, spec, 1) - bias = _last_A_pos.transpose(0, 1).matmul(gamma_pos).transpose(0, 1) + \ - _last_A_neg.transpose(0, 1).matmul(gamma_neg).transpose(0, 1) - - bias = bias.squeeze(-1) + A_x, A_y = BoundLinear.propagate_A_xy( + last_A, alpha_pos, alpha_neg, beta_pos, beta_neg, dim_y) + + if reduce_bias: + # last_uA has size (batch, spec, output) + # gamma_u has size (batch, output, 1) + # ubias has size (batch, spec, 1) + if self.opt_stage in ['opt', 'reuse']: + bias = (torch.einsum('sb...,sb...->sb', + last_A.clamp(min=0), gamma_pos) + + torch.einsum('sb...,sb...->sb', + last_A.clamp(max=0), gamma_neg)) + else: + bias = ( + self.get_bias(last_A.clamp(min=0), gamma_pos) + + self.get_bias(last_A.clamp(max=0), gamma_neg) + ) + else: + assert self.batch_dim == 0 + assert self.opt_stage not in ['opt', 'reuse'] + assert dim_y == [-3] + bias = (last_A.unsqueeze(-1).clamp(min=0) * gamma_pos + + last_A.unsqueeze(-1).clamp(max=0) * gamma_neg) + bias_x = bias.sum(dim=-2) + bias_y = bias.sum(dim=-3) + bias = (bias_x, bias_y) return A_x, A_y, bias - lA_x, lA_y, lbias = _bound_oneside(last_lA, alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u) - uA_x, uA_y, ubias = _bound_oneside(last_uA, alpha_u, beta_u, gamma_u, alpha_l, beta_l, gamma_l) + if self.opt_stage in ['opt', 'reuse']: + lA_x, lA_y, lbias = _bound_oneside( + last_lA, alpha_l[0], beta_l[0], gamma_l[0], + alpha_u[0], beta_u[0], gamma_u[0]) + uA_x, uA_y, ubias = _bound_oneside( + last_uA, alpha_u[1], beta_u[1], gamma_u[1], + alpha_l[1], beta_l[1], gamma_l[1]) + else: + lA_x, lA_y, lbias = _bound_oneside( + last_lA, alpha_l, beta_l, gamma_l, alpha_u, beta_u, gamma_u) + uA_x, uA_y, ubias = _bound_oneside( + last_uA, alpha_u, beta_u, gamma_u, alpha_l, beta_l, gamma_l) return [(lA_x, uA_x), (lA_y, uA_y)], lbias, ubias @@ -325,12 +469,14 @@ def _propagate_Linf(x, w): def interval_propagate(self, *v, C=None, w=None): has_bias = self is not None and len(v) == 3 if self is not None: - # This will convert an Interval object to tuple. We need to add perturbation property later. + # This will convert an Interval object to tuple. + # We need to add perturbation property later. v_lb, v_ub = zip(*v) v_lb = self._preprocess(*v_lb) v_ub = self._preprocess(*v_ub) # After preprocess the lower and upper bounds, we make them Intervals again. - v = [Interval.make_interval(bounds[0], bounds[1], bounds[2]) for bounds in zip(v_lb, v_ub, v)] + v = [Interval.make_interval(bounds[0], bounds[1], bounds[2]) + for bounds in zip(v_lb, v_ub, v)] if w is None and self is None: # Use C as the weight, no bias. w, lb, ub = C, torch.tensor(0., device=C.device), torch.tensor(0., device=C.device) @@ -362,7 +508,7 @@ def interval_propagate(self, *v, C=None, w=None): # interval_propagate() of the Linear layer may encounter input with different norms. norm, eps = Interval.get_perturbation(v[0])[:2] - if norm == np.inf: + if norm == torch.inf: interval = BoundLinear._propagate_Linf(v[0], w) center, deviation = interval elif norm > 0: @@ -378,7 +524,9 @@ def interval_propagate(self, *v, C=None, w=None): # mid has dimension [batch, input], w has dimension [output, input]. center = mid.matmul(w.t()) deviation = w.norm(dual_norm, dim=-1) * eps - else: # here we calculate the L0 norm IBP bound of Linear layers, using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020] + else: + # here we calculate the L0 norm IBP bound of Linear layers, + # using the bound proposed in [Certified Defenses for Adversarial Patches, ICLR 2020] norm, eps, ratio = Interval.get_perturbation(v[0]) mid = v[0][0] weight_abs = w.abs() @@ -401,7 +549,7 @@ def interval_propagate_with_weight(self, *v): input_norm, input_eps = Interval.get_perturbation(v[0]) weight_norm, weight_eps = Interval.get_perturbation(v[1]) - if input_norm == np.inf and weight_norm == np.inf: + if input_norm == torch.inf and weight_norm == torch.inf: # A memory-efficient implementation without expanding all the elementary multiplications if self.opt_matmul == 'economic': x_l, x_u = v[0][0], v[0][1] @@ -435,7 +583,7 @@ def interval_propagate_with_weight(self, *v): lower, upper = torch.sum(lower, -1), torch.sum(upper, -1) return lower, upper - elif input_norm == np.inf and weight_norm == 2: + elif input_norm == torch.inf and weight_norm == 2: # This eps is actually the epsilon per row, as only one row is involved for each output element. eps = weight_eps # Input data and weight are Linf perturbed (with upper and lower bounds). @@ -463,7 +611,7 @@ def bound_forward_mul(x_lw: Tensor, x_lb: Tensor, x_uw: Tensor, x_ub: Tensor, w: # w: an optional argument which can be utilized by BoundMatMul def bound_dynamic_forward(self, x, w=None, b=None, C=None, max_dim=None, offset=0): - assert not self.transA and self.alpha == 1.0 and self.transB and self.beta == 1.0 + assert not self.transA and self.alpha_linear == 1.0 and self.transB and self.beta_linear == 1.0 assert not self.is_input_perturbed(1) assert not self.is_input_perturbed(2) @@ -531,22 +679,16 @@ def bound_forward(self, dim_in, x, w=None, b=None, C=None): def bound_forward_with_weight(self, dim_in, x, y): x_unsqueeze = LinearBound( - x.lw.unsqueeze(-2), - x.lb.unsqueeze(-2), - x.uw.unsqueeze(-2), - x.ub.unsqueeze(-2), - x.lower.unsqueeze(-2), - x.upper.unsqueeze(-2), + x.lw.unsqueeze(-2), x.lb.unsqueeze(-2), + x.uw.unsqueeze(-2), x.ub.unsqueeze(-2), + x.lower.unsqueeze(-2), x.upper.unsqueeze(-2), ) y_unsqueeze = LinearBound( - y.lw.unsqueeze(-3), - y.lb.unsqueeze(-3), - y.uw.unsqueeze(-3), - y.ub.unsqueeze(-3), - y.lower.unsqueeze(-3), - y.upper.unsqueeze(-3), + y.lw.unsqueeze(-3), y.lb.unsqueeze(-3), + y.uw.unsqueeze(-3), y.ub.unsqueeze(-3), + y.lower.unsqueeze(-3), y.upper.unsqueeze(-3), ) - res_mul = BoundMul.bound_forward_both_perturbed(dim_in, x_unsqueeze, y_unsqueeze) + res_mul = self.bound_forward_both_perturbed(dim_in, x_unsqueeze, y_unsqueeze) return LinearBound( res_mul.lw.sum(dim=-1) if res_mul.lw is not None else None, res_mul.lb.sum(dim=-1), @@ -568,6 +710,8 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") # current layer weight (100, 1024) this_layer_weight = v[1] + if self.transB == 0: + this_layer_weight = this_layer_weight.transpose(1, 0) #### make sure if this is correct for per-label operations if C is not None: # merge specification C into last layer weights @@ -590,6 +734,16 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") for neuron_idx in range(this_layer_shape[0]): out_lb = out_lbs[neuron_idx] if out_lbs is not None else -float('inf') out_ub = out_ubs[neuron_idx] if out_ubs is not None else float('inf') + if out_ub - out_lb < EPS: + """ + If the inferred lb and ub are too close, it could lead to floating point disagreement + between solver's inferred lb and ub constraints and the computed ones from ab-crown. + Such disagreement can lead to "infeasible" result from the solver for feasible problem. + To avoid so, we relax the box constraints. + This should not affect the solver's result correctness, + since the tighter lb and ub can be inferred by the solver. + """ + out_lb, out_ub = (out_lb + out_ub - EPS) / 2., (out_lb + out_ub + EPS) / 2. lin_expr = 0 if has_bias: @@ -616,28 +770,26 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") model.update() def build_gradient_node(self, grad_upstream): - node_grad = LinearGrad(self.inputs[1].param) + if isinstance(self.inputs[1], BoundParams): + w = self.inputs[1].param + else: + w = self.inputs[1].value + if not self.transB: + w = w.t() + node_grad = LinearGrad(w) return node_grad, (grad_upstream,), [] + def update_requires_input_bounds(self): + self._check_weight_perturbation() + class BoundMatMul(BoundLinear): # Reuse most functions from BoundLinear. - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.transA = 0 self.transB = 0 - self.is_constant_op = False - for inp in inputs: - if BoundMul._check_const_input(inp): - # If any of the two inputs are constant, we do not need input bounds. - # FIXME (05/11/2022): this is just a temporary workaround. We need better way to determine whether we need input bounds, not just for BoundConstant. - self.is_constant_op = True - if self.is_constant_op: - # One input is constant; no bounds required. - self.requires_input_bounds = [1] - else: - # Both inputs are perturbed. Need relaxation. - self.requires_input_bounds = [0, 1] + self.splittable = True def forward(self, x, y): self.x_shape = x.shape @@ -650,12 +802,22 @@ def interval_propagate(self, *v): lower, upper = super().interval_propagate(*v) return lower, upper - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, start_node=None, **kwargs): assert len(x) == 2 - results = super().bound_backward(last_lA, last_uA, *x) + if start_node is not None: + self._start = start_node.name + results = super().bound_backward(last_lA, last_uA, *x, **kwargs) lA_y = results[0][1][0].transpose(-1, -2) if results[0][1][0] is not None else None uA_y = results[0][1][1].transpose(-1, -2) if results[0][1][1] is not None else None - return [results[0][0], (lA_y, uA_y), results[0][2]], results[1], results[2] + if isinstance(results[1], tuple): + lbias = (results[1][0], results[1][1].transpose(-1, -2)) + else: + lbias = results[1] + if isinstance(results[2], tuple): + ubias = (results[2][0], results[2][1].transpose(-1, -2)) + else: + ubias = results[2] + return [results[0][0], (lA_y, uA_y), results[0][2]], lbias, ubias def bound_forward(self, dim_in, x, y): return super().bound_forward(dim_in, x, LinearBound( @@ -667,25 +829,46 @@ def bound_forward(self, dim_in, x, y): y.upper.transpose(-1, -2) if y.upper is not None else None )) + def update_requires_input_bounds(self): + self.is_linear_op = False + for inp in self.inputs: + if not inp.perturbed: + # If any of the two inputs are constant, we do not need input bounds. + self.is_linear_op = True + if self.is_linear_op: + # One input is constant; no bounds required. + self.requires_input_bounds = [] + self.splittable = False + else: + # Both inputs are perturbed. Need relaxation. + self.requires_input_bounds = [0, 1] + self.splittable = True + + class BoundNeg(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) + self.ibp_intermediate = True def forward(self, x): return -x - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): if type(last_lA) == Tensor or type(last_uA) == Tensor: return [(-last_lA if last_lA is not None else None, -last_uA if last_uA is not None else None)], 0, 0 elif type(last_lA) == Patches or type(last_uA) == Patches: if last_lA is not None: - lA = Patches(-last_lA.patches, last_lA.stride, last_lA.padding, last_lA.shape, unstable_idx=last_lA.unstable_idx, output_shape=last_lA.output_shape) + lA = Patches(-last_lA.patches, last_lA.stride, last_lA.padding, + last_lA.shape, unstable_idx=last_lA.unstable_idx, + output_shape=last_lA.output_shape) else: lA = None if last_uA is not None: - uA = Patches(-last_uA.patches, last_uA.stride, last_uA.padding, last_uA.shape, unstable_idx=last_uA.unstable_idx, output_shape=last_uA.output_shape) + uA = Patches(-last_uA.patches, last_uA.stride, last_uA.padding, + last_uA.shape, unstable_idx=last_uA.unstable_idx, + output_shape=last_uA.output_shape) else: uA = None return [(lA, uA)], 0, 0 @@ -700,7 +883,7 @@ def interval_propagate(self, *v): class BoundCumSum(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.use_default_ibp = True @@ -709,14 +892,14 @@ def forward(self, x, axis): return torch.cumsum(x, axis) class BoundIdentity(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.use_default_ibp = True def forward(self, x): return x - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): return [(last_lA, last_uA)], 0, 0 def bound_forward(self, dim_in, x): diff --git a/auto_LiRPA/operators/logical.py b/auto_LiRPA/operators/logical.py index 179a7ac..e70ccc6 100644 --- a/auto_LiRPA/operators/logical.py +++ b/auto_LiRPA/operators/logical.py @@ -3,9 +3,6 @@ class BoundWhere(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - def forward(self, condition, x, y): return torch.where(condition.to(torch.bool), x, y) @@ -14,7 +11,7 @@ def interval_propagate(self, *v): condition = v[0][0] return tuple([torch.where(condition, v[1][j], v[2][j]) for j in range(2)]) - def bound_backward(self, last_lA, last_uA, condition, x, y): + def bound_backward(self, last_lA, last_uA, condition, x, y, **kwargs): assert torch.allclose(condition.lower.float(), condition.upper.float()) assert self.from_input mask = condition.lower.float() @@ -33,8 +30,10 @@ def _bound_oneside(last_A): return [(None, None), (lA_x, uA_x), (lA_y, uA_y)], 0, 0 class BoundNot(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - def forward(self, x): - return x.logical_not() \ No newline at end of file + return x.logical_not() + + +class BoundEqual(Bound): + def forward(self, x, y): + return x == y diff --git a/auto_LiRPA/operators/nonlinear.py b/auto_LiRPA/operators/nonlinear.py index 601377d..20af28e 100644 --- a/auto_LiRPA/operators/nonlinear.py +++ b/auto_LiRPA/operators/nonlinear.py @@ -1,388 +1,245 @@ """Unary nonlinearities other than activation functions.""" -import math import torch -from .activation_base import BoundActivation -from .activations import BoundTanh -from .base import epsilon, LinearBound - - -class BoundSin(BoundActivation): - # Lookup tables shared by all BoundSin classes. - xl_lower_tb = None - xl_upper_tb = None - xu_lower_tb = None - xu_upper_tb = None - func, d_func = torch.sin, torch.cos - n_table_entries = 1001 - - @staticmethod - def n_crossing(start, end, s): - """Check how many times we will encounter value s + k*2*pi within start and end for any integer k.""" - dtype = start.dtype - cycles = torch.floor((end - start) / (2 * math.pi)) # Number of 2pi cycles. - # Move s and end to the same 2 * pi cycle as start. - dist = torch.floor((s - start) / (2 * math.pi)) - real_s = s - dist * 2 * math.pi - real_end = end - cycles * 2 * math.pi - # assert (real_end >= start - 2 ** (-20)).all() - return (real_s >= start).to(dtype) * (real_s <= real_end).to(dtype) + cycles - - @staticmethod - def get_intersection(start, end, c, theta=0.): - """Get the number of intersections between y = sin(x + theta) and y = c between start and end.""" - # Use arcsine to find the first 2 intersections. - crossing1 = torch.arcsin(c) - theta - crossing2 = math.pi - crossing1 - 2 * theta # Problematic at exact 1/2 pi, but ok in our case (happens only when lb=ub). - return BoundSin.n_crossing(start, end, crossing1) + BoundSin.n_crossing(start, end, crossing2) - - @staticmethod - def get_bounding_slope(xl, xu, c, theta=0.): - """Find the point between xl and xu such that the tangent line at that point is a lower/upper bound.""" - dtype = xl.dtype - crossing1 = torch.arcsin(c) - theta # output is [-0.5 pi, 0.5 pi] - theta. For cosine, theta=0.5 pi and crossing point is between -pi to 0. - crossing2 = math.pi - crossing1 - 2 * theta # output is [0.5pi, 1.5pi] - theta. For cosine, it is between 0 and pi. - # Find the crossing point between xl and xu. - # First see how xl is away from the [-0.5pi, 1.5pi] range for sine or [-pi, pi] range for cosine. - cycles1 = torch.floor((crossing1 - xl) / (2 * math.pi)) * 2 * math.pi - # Move the two crossing points to the same cycle as xl. - crossing1_moved = crossing1 - cycles1 - cycles2 = torch.floor((crossing2 - xl) / (2 * math.pi)) * 2 * math.pi - crossing2_moved = crossing2 - cycles2 - # Then check which crossing point is the actual tangent point between xl and xu. - crossing1_used = (crossing1_moved >= xl).to(dtype) * (crossing1_moved <= xu).to(dtype) - crossing2_used = (crossing2_moved >= xl).to(dtype) * (crossing2_moved <= xu).to(dtype) - crossing_point = crossing1_used * crossing1_moved + crossing2_used * crossing2_moved - # print(f'c1={crossing1.item():.05f}, c2={crossing2.item():.05f}, cy1={cycles1.item():.05f}, cy2={cycles2.item():.05f}, c1m={crossing1_moved.item():.05f}, c2m={crossing2_moved.item():.05f}, u1={crossing1_used.item()}, u2={crossing2_used.item()}, xl={xl.item():.05f}, xu={xu.item():.05f}') - return crossing_point - - @staticmethod - def check_bound(tangent_point, x): - """Check whether the tangent line at tangent_point is a valid lower/upper bound for x.""" - # evaluate the value of the tangent line at x and see it is >= 0 or <=0. - d = BoundSin.d_func(tangent_point) - val = d * (x - tangent_point) + BoundSin.func(tangent_point) - # We want a positive margin when finding a lower line, but as close to 0 as possible. - # We want a negative margin when finding a upper line, but as close to 0 as possible. - margin = BoundSin.func(x) - val - return margin - - @staticmethod +import torch.nn.functional as F +from .activation_base import BoundActivation, BoundOptimizableActivation +from .base import * +from .clampmult import multiply_by_A_signs +from .tanh import BoundTanh + + +# TODO too much code in this class is a duplicate of BoundTanh +class BoundOptimizableNonLinear(BoundTanh): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + # FIXME temporary: precompute=False + super().__init__(attr, inputs, output_index, options, + precompute=False) + # activation function needs to be nn.module + self.splittable = True + self.act_func = None + self.d_act_func = None + self.inflections = [] + self.extremes = [] + + def branch_input_domain(self, lb, ub): + lower = lb + upper = ub + num_inflection = torch.zeros_like(lower) + inflection_mat = lower + for inflection in self.inflections: + num_inflection += torch.logical_and( + lower <= inflection, upper >= inflection) + inflection_mat = torch.where( + torch.logical_and(lower <= inflection, upper >= inflection), + torch.tensor(inflection, device=lb.device), inflection_mat) + inflection_mask = num_inflection <= 1. + + extreme_mask = torch.ones_like(lower) + for extreme in self.extremes: + extreme_mask *= torch.logical_or(lower >= extreme, upper <= extreme) + + self.sigmoid_like_mask = torch.logical_and(inflection_mask, extreme_mask) + self.branch_mask = torch.logical_xor(torch.ones_like(lower), self.sigmoid_like_mask) + self.inflection_mat = torch.where(self.sigmoid_like_mask, inflection_mat, lower) + + self.mask_neg = torch.logical_and((self.d2_act_func(lower) >= 0), + torch.logical_and((self.d2_act_func(upper) >= 0), + self.sigmoid_like_mask)) + self.mask_pos = torch.logical_and((self.d2_act_func(lower) < 0), + torch.logical_and((self.d2_act_func(upper) < 0), + self.sigmoid_like_mask)) + self.mask_both = torch.logical_xor(self.sigmoid_like_mask, + torch.logical_or(self.mask_neg, self.mask_pos)) + + self.convex_concave = torch.logical_and(self.mask_both, + (self.d2_act_func(lower) >= 0)) + self.concave_convex = torch.logical_xor(self.mask_both, self.convex_concave) + + # FIXME @torch.no_grad() - def get_lower_left_bound(xl, steps=20): - """Get a global lower bound given lower bound on x. Return slope and intercept.""" - dtype = xl.dtype - # Constrain xl into the -0.5 pi to 1.5 pi region. - cycles = torch.floor((xl + 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi) - xl = xl - cycles - use_tangent_line = (xl >= math.pi).to(dtype) - # Case 1: xl > pi, Lower tangent line is the only possible lower bound. - case1_d = BoundSin.d_func(xl) - case1_b = BoundSin.func(xl) - case1_d * (xl + cycles) - # Case 2: Binary search needed. Testing from another tangent endpoint in [pi, 1.5*pi]. It must be in this region. - left = math.pi * torch.ones_like(xl) - # The right end guarantees the margin > 0 because it is basically a IBP lower bound (-1). - right = (1.5 * math.pi) * torch.ones_like(xl) - last_right = right.clone() - for i in range(steps): - mid = (left + right) / 2. - margin = BoundSin.check_bound(mid, xl) - pos_mask = (margin > 0).to(dtype) # We want to margin > 0 but at small as possible. - neg_mask = 1.0 - pos_mask - right = mid * pos_mask + right * neg_mask # We have positive margin, reduce right hand side. - last_right = mid * pos_mask + last_right * neg_mask # Always sound, since the margin is positive. - left = mid * neg_mask + left * pos_mask - case2_d = BoundSin.d_func(last_right) - case2_b = BoundSin.func(last_right) - case2_d * (last_right + cycles) - d = case1_d * use_tangent_line + case2_d * (1. - use_tangent_line) - b = case1_b * use_tangent_line + case2_b * (1. - use_tangent_line) - # Return slope and bias. - return [d, b] - - @staticmethod - @torch.no_grad() - def get_upper_left_bound(xl, steps=20): - dtype = xl.dtype - """Get a global upper bound given lower bound on x. Return slope and intercept.""" - # Constrain xl into the -0.5 pi to 1.5 pi region. - cycles = torch.floor((xl - 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi) - xl = xl - cycles - use_tangent_line = (xl >= 2.0 * math.pi).to(dtype) - # Case 1: xl > pi, Lower tangent line is the only possible lower bound. - case1_d = BoundSin.d_func(xl) - case1_b = BoundSin.func(xl) - case1_d * (xl + cycles) - # Case 2: Binary search needed. Testing from another tangent endpoint in [pi, 1.5*pi]. It must be in this region. - left = (2.0 * math.pi) * torch.ones_like(xl) - # The right end guarantees the margin > 0 because it is basically a IBP lower bound (-1). - right = (2.5 * math.pi) * torch.ones_like(xl) - last_right = right.clone() - for i in range(steps): - mid = (left + right) / 2. - margin = BoundSin.check_bound(mid, xl) - pos_mask = (margin > 0).to(dtype) # We want to margin < 0 but at small as possible. - neg_mask = 1.0 - pos_mask - right = mid * neg_mask + right * pos_mask # We have positive margin, reduce right hand side. - last_right = mid * neg_mask + last_right * pos_mask # Always sound, since the margin is positive. - left = mid * pos_mask + left * neg_mask - case2_d = BoundSin.d_func(last_right) - case2_b = BoundSin.func(last_right) - case2_d * (last_right + cycles) - d = case1_d * use_tangent_line + case2_d * (1. - use_tangent_line) - b = case1_b * use_tangent_line + case2_b * (1. - use_tangent_line) - # Return slope and bias. - return [d, b] - - @staticmethod - @torch.no_grad() - def get_lower_right_bound(xu, steps=20): - """Get a global lower bound given upper bound on x. Return slope and intercept.""" - # Constrain xu into the -0.5 pi to 1.5 pi region. - cycles = torch.floor((xu + 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi) - xu = xu - cycles - d, _ = BoundSin.get_lower_left_bound(math.pi - xu, steps) - return [-d, BoundSin.func(xu) + d * (xu + cycles)] - - @staticmethod - @torch.no_grad() - def get_upper_right_bound(xu, steps=20): - """Get a global upper bound given upper bound on x. Return slope and intercept.""" - # Constrain xu into the 0.5 pi to 2.5 pi region. - cycles = torch.floor((xu - 0.5 * math.pi) / (2 * math.pi)) * (2 * math.pi) - xu = xu - cycles - d, _ = BoundSin.get_upper_left_bound(3 * math.pi - xu, steps) - return [-d, BoundSin.func(xu) + d * (xu + cycles)] - - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - # Bound limits used by IBP. - self.max_point = math.pi / 2 - self.min_point = math.pi * 3 / 2 - - self.all_table_x = torch.linspace(0, 2 * math.pi, BoundSin.n_table_entries, device=self.device) - if BoundSin.xl_lower_tb is None: - # Generate look-up tables. - BoundSin.xl_lower_tb = BoundSin.get_lower_left_bound(self.all_table_x) - BoundSin.xl_upper_tb = BoundSin.get_upper_left_bound(self.all_table_x) - BoundSin.xu_lower_tb = BoundSin.get_lower_right_bound(self.all_table_x) - BoundSin.xu_upper_tb = BoundSin.get_upper_right_bound(self.all_table_x) - BoundSin.xl_lower_tb[0], BoundSin.xl_lower_tb[1] = BoundSin.xl_lower_tb[0].to(self.device), BoundSin.xl_lower_tb[1].to(self.device) - BoundSin.xl_upper_tb[0], BoundSin.xl_upper_tb[1] = BoundSin.xl_upper_tb[0].to(self.device), BoundSin.xl_upper_tb[1].to(self.device) - BoundSin.xu_lower_tb[0], BoundSin.xu_lower_tb[1] = BoundSin.xu_lower_tb[0].to(self.device), BoundSin.xu_lower_tb[1].to(self.device) - BoundSin.xu_upper_tb[0], BoundSin.xu_upper_tb[1] = BoundSin.xu_upper_tb[0].to(self.device), BoundSin.xu_upper_tb[1].to(self.device) - - @staticmethod - def interpoloate(x, lower_x, upper_x, lower_y, upper_y): - # x = torch.clamp(x, min=lower_x, max=upper_x) # For pytorch >= 1.11 - x = torch.max(torch.min(x, upper_x), lower_x) - ratio = (x - lower_x) / (upper_x - lower_x + 1e-10) - return lower_y * (1. - ratio) + upper_y * ratio - - def get_bound_tb(self, tb, x): - """Find lower or upper bounds from lookup table.""" - step = 2 * math.pi / (BoundSin.n_table_entries - 1) - # Move to 0 to 2 pi region. - cycles = torch.floor(x / (2 * math.pi)) * (2 * math.pi) - x = torch.clamp(x - cycles, min=0, max=2 * math.pi) - # Find the indice within the lookup table from 0 - 2pi. - indices = x.div(step).long() - # Intepolate the nearest d and b. This has better differentiability. - # Another option is to always take the lower/upper side (better soundness). - upper_indices = torch.clamp(indices + 1, max=BoundSin.n_table_entries-1) - lower_x = self.all_table_x[indices] - upper_x = self.all_table_x[upper_indices] - # print(indices.item(), lower_x.item(), upper_x.item(), tb[0][indices].item(), tb[0][upper_indices].item()) - d = self.interpoloate(x, lower_x, upper_x, tb[0][indices], tb[0][upper_indices]) - b = self.interpoloate(x, lower_x, upper_x, tb[1][indices], tb[1][upper_indices]) - return d, b - d * cycles - - def forward(self, x): - return torch.sin(x) - - def interval_propagate(self, *v): - # Check if a point is in [l, u], considering the 2pi period - def check_crossing(ll, uu, point): - return ((((uu - point) / (2 * math.pi)).floor() - ((ll - point) / (2 * math.pi)).floor()) > 0).to(h_Ls.dtype) - h_L, h_U = v[0][0], v[0][1] - h_Ls, h_Us = self.forward(h_L), self.forward(h_U) - # If crossing pi/2, then max is fixed 1.0 - max_mask = check_crossing(h_L, h_U, self.max_point) - # If crossing pi*3/2, then min is fixed -1.0 - min_mask = check_crossing(h_L, h_U, self.min_point) - ub = torch.max(h_Ls, h_Us) - ub = max_mask + (1 - max_mask) * ub - lb = torch.min(h_Ls, h_Us) - lb = - min_mask + (1 - min_mask) * lb - return lb, ub + def precompute_relaxation(self, func, dfunc, x_limit = 500): + return super().precompute_relaxation('nonlinear', func, dfunc, x_limit) + + def generate_d_lower_upper(self, lower, upper): + # Indices of neurons with input upper bound >=0, whose optimal slope to + # lower bound the function was pre-computed. + # Note that for neurons with also input lower bound >=0, + # they will be masked later. + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), + (upper / self.step_pre).to(torch.long).reshape(-1) + ) + 1 + # Lookup the lower bound slope from the pre-computed table. + d_lower = torch.index_select(self.d_lower, 0, index).view(lower.shape) + + # Indices of neurons with lower bound <=0, whose optimal slope to upper + # bound the function was pre-computed. + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + (lower / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + d_upper = torch.index_select(self.d_upper, 0, index).view(upper.shape) + return d_lower, d_upper + + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + shape = [size_spec] + list(l.shape) + alpha = torch.empty(10, *shape, device=l.device) + alpha.data[:4] = ((l + u) / 2).unsqueeze(0).expand(4, *shape) + alpha.data[4:6] = self.tp_both_lower_init[name_start].expand(2, *shape) + alpha.data[6:8] = self.tp_both_upper_init[name_start].expand(2, *shape) + return alpha + + def bound_relax_impl_sigmoid(self, lb, ub, func, dfunc): + # When self.x_limit is large enough, torch.tanh(self.x_limit)=1, + # and thus clipping is valid + lower = lb + upper = ub + y_l, y_u = func(lower), func(upper) + + # k_direct is the slope of the line directly connect (lower, func(lower)), (upper, func(upper)). + k_direct = k = (y_u - y_l) / (upper - lower).clamp(min=1e-8) + + # Fixed bounds that cannot be optimized. self.mask_neg are the masks for neurons with upper bound <= 0. + # Upper bound for the case of input lower bound <= 0, is always the direct line. + self.add_linear_relaxation( + mask=self.mask_neg, type='upper', k=k, x0=lower, y0=y_l) + # Lower bound for the case of input upper bound >= 0, is always the direct line. + self.add_linear_relaxation( + mask=self.mask_pos, type='lower', k=k, x0=lower, y0=y_l) + + if self.use_precompute: + d_lower, d_upper = self.generate_d_lower_upper(lower, upper) + else: + d_lower = self.convex_concave * lower + self.concave_convex * upper + d_upper = self.convex_concave * upper + self.concave_convex * lower + + if self.opt_stage in ['opt', 'reuse']: + if not hasattr(self, 'alpha'): + # Raise an error if alpha is not created. + self._no_bound_parameters() + ns = self._start + + self.alpha[ns].data[0:2, :] = torch.max( + torch.min(self.alpha[ns][0:2, :], upper), lower) + self.alpha[ns].data[2:4, :] = torch.max( + torch.min(self.alpha[ns][2:4, :], upper), lower) + self.alpha[ns].data[4:6, :] = ( + self.convex_concave * torch.max(lower, torch.min(self.alpha[ns][4:6, :], d_lower)) + + self.concave_convex * torch.min(upper, torch.max(self.alpha[ns][4:6, :], d_lower))) + self.alpha[ns].data[6:8, :] = ( + self.convex_concave * torch.min(upper, torch.max(self.alpha[ns][6:8, :], d_upper)) + + self.concave_convex * torch.max(lower, torch.min(self.alpha[ns][6:8, :], d_upper))) + + # shape [2, out_c, n, c, h, w]. + tp_pos = self.alpha[ns][0:2, :] + tp_neg = self.alpha[ns][2:4, :] + tp_both_lower = self.alpha[ns][4:6, :] + tp_both_upper = self.alpha[ns][6:8, :] + + # No need to use tangent line, when the tangent point is at the left + # side of the preactivation lower bound. Simply connect the two sides. + mask_direct = torch.logical_or( + torch.logical_and(self.convex_concave, k_direct < dfunc(lower)), + torch.logical_and(self.concave_convex, k_direct > dfunc(upper))) + self.add_linear_relaxation( + mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), type='lower', + k=dfunc(tp_both_lower), x0=tp_both_lower, y0=func(tp_both_lower)) + + mask_direct = torch.logical_or( + torch.logical_and(self.convex_concave, k_direct < dfunc(upper)), + torch.logical_and(self.concave_convex, k_direct > dfunc(lower))) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), type='upper', + k=dfunc(tp_both_upper), x0=tp_both_upper, y0=func(tp_both_upper)) + + self.add_linear_relaxation( + mask=self.mask_neg, type='lower', k=dfunc(tp_neg), + x0=tp_neg, y0=func(tp_neg)) + self.add_linear_relaxation( + mask=self.mask_pos, type='upper', k=dfunc(tp_pos), + x0=tp_pos, y0=func(tp_pos)) + else: + # Not optimized (vanilla CROWN bound). + # Use the middle point slope as the lower/upper bound. Not optimized. + m = (lower + upper) / 2 + y_m = func(m) + k = dfunc(m) + # Lower bound is the middle point slope for the case input upper bound <= 0. + # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation( + mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m) + # Upper bound is the middle point slope for the case input lower bound >= 0. + # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation( + mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m) + + # Now handle the case where input lower bound <=0 and upper bound >= 0. + # A tangent line starting at d_lower is guaranteed to be a lower bound given the input upper bound. + k = dfunc(d_lower) + y0 = func(d_lower) + if self.opt_stage == 'init': + # Initialize optimizable slope. + ns = self._start + self.tp_both_lower_init[ns] = d_lower.detach() + # Another possibility is to use the direct line as the lower bound, when this direct line does not intersect with f. + # This is only valid when the slope at the input lower bound has a slope greater than the direct line. + mask_direct = torch.logical_or( + torch.logical_and(self.convex_concave, k_direct < dfunc(lower)), + torch.logical_and(self.concave_convex, k_direct > dfunc(upper))) + self.add_linear_relaxation( + mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + # Otherwise we do not use the direct line, we use the d_lower slope. + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), + type='lower', k=k, x0=d_lower, y0=y0) + + # Do the same for the upper bound side when input lower bound <=0 and upper bound >= 0. + k = dfunc(d_upper) + y0 = func(d_upper) + if self.opt_stage == 'init': + ns = self._start + self.tp_both_upper_init[ns] = d_upper.detach() + self.tmp_lower = lb.detach() + self.tmp_upper = ub.detach() + mask_direct = torch.logical_or( + torch.logical_and(self.convex_concave, k_direct < dfunc(upper)), + torch.logical_and(self.concave_convex, k_direct > dfunc(lower))) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), + type='upper', k=k, x0=d_upper, y0=y0) + + def generate_inflections(self, lb, ub): + raise NotImplementedError def bound_relax_impl(self, lb, ub): - dtype = lb.dtype - # Case 1: Connect the two points as a line - sub = self.func(ub) - slb = self.func(lb) - mid = (sub + slb) / 2. - smid = self.func((ub + lb) / 2) - case1_line_slope = (sub - slb) / (ub - lb + 1e-10) - case1_line_bias = slb - case1_line_slope * lb - gap = smid - mid - # Check if there are crossings between the line and the sin function. - grad_crossings = self.get_intersection(lb, ub, case1_line_slope, theta=0.5 * math.pi) - # If there is no crossing, then we can connect the two points together as a lower/upper bound. - use_line = grad_crossings == 1 - # Connected line is the upper bound. - upper_use_line = torch.logical_and(gap < 0, use_line) - # Connected line is the lower bound. - lower_use_line = torch.logical_and(gap >= 0, use_line) - # For the other bound, use the tangent line. - case1_tangent_point = self.get_bounding_slope(lb, ub, case1_line_slope, theta=0.5 * math.pi) - case1_tangent_slope = case1_line_slope # Use the same slope so far. - stangent = self.func(case1_tangent_point) - case1_tangent_bias = stangent - case1_tangent_slope * case1_tangent_point - # Choose the lower/upper based on gap. - case1_lower_slope = lower_use_line * case1_line_slope + upper_use_line * case1_tangent_slope - case1_lower_bias = lower_use_line * case1_line_bias + upper_use_line * case1_tangent_bias - case1_upper_slope = upper_use_line * case1_line_slope + lower_use_line * case1_tangent_slope - case1_upper_bias = upper_use_line * case1_line_bias + lower_use_line * case1_tangent_bias - - # Case 2: we will try the global lower/upper bounds at lb and ub. - # For the points and lb and ub, we can construct both lower and upper bounds. - left_lower = self.get_bound_tb(BoundSin.xl_lower_tb, lb) # slope, bias. - left_upper = self.get_bound_tb(BoundSin.xl_upper_tb, lb) - right_lower = self.get_bound_tb(BoundSin.xu_lower_tb, ub) - right_upper = self.get_bound_tb(BoundSin.xu_upper_tb, ub) - # Determine which lower bound is tighter. - left_lower_error = sub - (left_lower[0] * ub + left_lower[1]) - right_lower_error = slb - (right_lower[0] * lb + right_lower[1]) - left_upper_error = (left_upper[0] * ub + left_upper[1]) - sub - right_upper_error = (right_upper[0] * lb + right_upper[1]) - slb - use_left_lower = (left_lower_error < right_lower_error).to(dtype) - use_right_lower = 1. - use_left_lower - use_left_upper = (left_upper_error < right_upper_error).to(dtype) - use_right_upper = 1. - use_left_upper - # Choose the slope and bias in this case. - case_2_lower_slope = use_left_lower * left_lower[0] + use_right_lower * right_lower[0] - case_2_lower_bias = use_left_lower * left_lower[1] + use_right_lower * right_lower[1] - case_2_upper_slope = use_left_upper * left_upper[0] + use_right_upper * right_upper[0] - case_2_upper_bias = use_left_upper * left_upper[1] + use_right_upper * right_upper[1] - - # Finally, choose between case 1 and case 2. - use_line = use_line.to(dtype) - not_use_line = 1. - use_line - lower_slope = use_line * case1_lower_slope + not_use_line * case_2_lower_slope - lower_bias = use_line * case1_lower_bias + not_use_line * case_2_lower_bias - upper_slope = use_line * case1_upper_slope + not_use_line * case_2_upper_slope - upper_bias = use_line * case1_upper_bias + not_use_line * case_2_upper_bias - # print(gap, lower_slope, lower_bias, upper_slope, upper_bias) - return lower_slope, lower_bias, upper_slope, upper_bias - - def bound_relax(self, x): - lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(x.lower, x.upper) - self.lw = lower_slope - self.lb = lower_bias - self.uw = upper_slope - self.ub = upper_bias - - -class BoundCos(BoundSin): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.max_point = 0.0 - self.min_point = math.pi - - def forward(self, x): - return torch.cos(x) - - def bound_relax(self, x): - # Shift the input by 0.5*pi, and shifting the linear bounds back. - lb = x.lower + 0.5 * math.pi - ub = x.upper + 0.5 * math.pi + raise NotImplementedError + + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + lb = x.lower + ub = x.upper + self.generate_inflections(lb, ub) + self.branch_input_domain(lb, ub) + self.bound_relax_impl_sigmoid(lb, ub, self.act_func, self.d_act_func) lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(lb, ub) - self.lw = lower_slope - self.lb = lower_slope * (0.5 * math.pi) + lower_bias - self.uw = upper_slope - self.ub = upper_slope * (0.5 * math.pi) + upper_bias - - -class BoundAtan(BoundTanh): - def __init__(self, attr, inputs, output_index, options): - super(BoundTanh, self).__init__(attr, inputs, output_index, options) - self.precompute_relaxation('arctan', torch.arctan, self.darctan) - # Alpha dimension is (4, 2, output_shape, batch, *shape) for S-shaped functions. - self.alpha_batch_dim = 3 - - def forward(self, x): - return torch.arctan(x) - - def darctan(self, x): - return (x.square() + 1.).reciprocal() - - def bound_relax(self, x): - self.bound_relax_impl(x, torch.arctan, self.darctan) - + self.lw = self.lw * self.sigmoid_like_mask + self.branch_mask * lower_slope + self.lb = self.lb * self.sigmoid_like_mask + self.branch_mask * lower_bias + self.uw = self.uw * self.sigmoid_like_mask + self.branch_mask * upper_slope + self.ub = self.ub * self.sigmoid_like_mask + self.branch_mask * upper_bias -class BoundTan(BoundAtan): - """ - The implementation of BoundTan is based on the S-shaped BoundAtan. We use the bounds from its - inverse function and directly convert the bounds of the inverse function to bounds of the original - function. This trick allows us to quickly implement bounds on inverse functions. - """ - def forward(self, x): - return torch.tan(x) - - def _check_bounds(self, lower, upper): - # Lower and upper bounds must be within the same [-½π, ½π] region. - lower_periods = torch.floor((lower + 0.5 * torch.pi) / torch.pi) - upper_periods = torch.floor((upper + 0.5 * torch.pi) / torch.pi) - if not torch.allclose(lower_periods, upper_periods): - print('Tan preactivation lower bounds:\n', lower) - print('Tan preactivation upper bounds:\n', upper) - raise ValueError("BoundTan received pre-activation bounds that produce infinity. " - "The preactivation bounds are too loose. Try to reduce perturbation region.") - # Return the period number for each neuron. - # Period is 0 => bounds are within [-½π, ½π], - # Period is 1 => bounds are within [-½π + π, ½π + π] - # Period is -1 => bounds are within [-½π - π, ½π - π] - return lower_periods - - def _init_masks(self, x): - # The masks now must consider the periodicity. - lower = torch.remainder(x.lower + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi - upper = torch.remainder(x.upper + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi - self.mask_pos = lower >= 0 - self.mask_neg = upper <= 0 - self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg)) - - def interval_propagate(self, *v): - # We need to check if the input lower and upper bounds are within the same period. - # Otherwise the bounds become infinity. - concrete_lower, concrete_upper = v[0][0], v[0][1] - self._check_bounds(concrete_lower, concrete_upper) - return super().interval_propagate(*v) - - def bound_relax(self, x): - periods = self._check_bounds(x.lower, x.upper) - periods = torch.pi * periods - # Create a fake x with inversed lower and upper. - inverse_x = lambda: None - inverse_x.lower = torch.tan(x.lower) - inverse_x.upper = torch.tan(x.upper) - super().bound_relax(inverse_x) - # Lower slope, lower bias, upper slope and upper bias are saved to - # self.lw, self.lb, self.uw, self.ub. We need to reverse them. - # E.g., y = self.lw * x + self.lb, now becomes x = 1./self.lw * y - self.lb / self.lw - # Additionally, we need to add the missing ½π periods. - new_upper_slope = 1. / self.lw - new_upper_bias = - self.lb / self.lw - periods / self.lw - new_lower_slope = 1. / self.uw - new_lower_bias = - self.ub / self.uw - periods / self.uw - self.lw = new_lower_slope - self.lb = new_lower_bias - self.uw = new_upper_slope - self.ub = new_upper_bias - - -class BoundExp(BoundActivation): - def __init__(self, attr, inputs, output_index, options): +class BoundExp(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) - self.options = options.get('exp') + self.options = options.get('exp', {}) self.max_input = 0 def forward(self, x): @@ -417,7 +274,7 @@ def bound_forward(self, dim_in, x): return LinearBound(lw, lb, uw, ub) - def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): + def bound_backward(self, last_lA, last_uA, x, **kwargs): # Special case when computing log_softmax (FIXME: find a better solution, this trigger condition is not reliable). if self.loss_fusion and last_lA is None and last_uA is not None and torch.min( last_uA) >= 0 and x.from_input: @@ -433,9 +290,9 @@ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None) adjusted_upper = x.upper - self.max_input # relaxation for upper bound only (used in loss fusion) exp_l, exp_u = torch.exp(adjusted_lower), torch.exp(adjusted_upper) - k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower + epsilon) + k = (exp_u - exp_l) / (adjusted_upper - adjusted_lower).clamp(min=1e-8) if k.requires_grad: - k = k.clamp(min=1e-6) + k = k.clamp(min=1e-8) uA = last_uA * k.unsqueeze(0) ubias = last_uA * (-adjusted_lower * k + exp_l).unsqueeze(0) @@ -455,21 +312,32 @@ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None) ubias -= (A.reshape(batch_size, -1) * self.max_input.reshape(batch_size, -1)).sum(dim=-1).unsqueeze(0) return [(None, uA)], 0, ubias else: - return super().bound_backward(last_lA, last_uA, x) + return super().bound_backward(last_lA, last_uA, x, **kwargs) - def bound_relax(self, x): + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) min_val = -1e9 l, u = x.lower.clamp(min=min_val), x.upper.clamp(min=min_val) - m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) + if self.opt_stage in ['opt', 'reuse']: + self.alpha[self._start].data[:2] = torch.min(torch.max( + self.alpha[self._start].data[:2], x.lower), x.upper) + m = torch.min(self.alpha[self._start], x.lower + 0.99) + else: + m = torch.min((x.lower + x.upper) / 2, x.lower + 0.99) exp_l, exp_m, exp_u = torch.exp(x.lower), torch.exp(m), torch.exp(x.upper) k = exp_m self.add_linear_relaxation(mask=None, type='lower', k=k, x0=m, y0=exp_m) - min_val = -1e9 # to avoid (-inf)-(-inf) when both input.lower and input.upper are -inf - epsilon = 1e-20 - close = (u - l < epsilon).int() - k = close * exp_u + (1 - close) * (exp_u - exp_l) / (u - l + epsilon) + k = (exp_u - exp_l) / (u - l).clamp(min=1e-8) self.add_linear_relaxation(mask=None, type='upper', k=k, x0=l, y0=exp_l) + def _init_opt_parameters_impl(self, size_spec, **kwargs): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + alpha = torch.empty(2, size_spec, *l.shape, device=l.device) + alpha.data[:2] = (l + u) / 2 + return alpha + class BoundLog(BoundActivation): @@ -479,7 +347,9 @@ def forward(self, x): return torch.logsumexp(self.inputs[0].inputs[0].inputs[0].forward_value, dim=-1) return torch.log(x.clamp(min=epsilon)) - def bound_relax(self, x): + def bound_relax(self, x, init=False): + if init: + self.init_linear_relaxation(x) rl, ru = self.forward(x.lower), self.forward(x.upper) ku = (ru - rl) / (x.upper - x.lower + epsilon) self.add_linear_relaxation(mask=None, type='lower', k=ku, x0=x.lower, y0=rl) @@ -497,7 +367,7 @@ def interval_propagate(self, *v): return lower, upper return super().interval_propagate(*v) - def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None): + def bound_backward(self, last_lA, last_uA, x, **kwargs): A, lbias, ubias = super().bound_backward(last_lA, last_uA, x) # NOTE adhoc implementation for loss fusion if self.loss_fusion: @@ -507,49 +377,173 @@ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None) return A, lbias, ubias -class BoundPow(BoundActivation): +class BoundPow(BoundOptimizableNonLinear): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.ibp_intermediate = False + self.use_precompute = True + self.has_constraint = True + self.exponent = 2 + def act_func(x): + return torch.pow(x, self.exponent) + self.act_func = act_func + def d_act_func(x): + return self.exponent * torch.pow(x, self.exponent - 1) + self.d_act_func = d_act_func + def d2_act_func(x): + return self.exponent * (self.exponent - 1) * torch.pow(x, self.exponent - 2) + self.d2_act_func = d2_act_func + + def generate_inflections(self, lb, ub): + if self.exponent % 2: + self.inflections = [0.] + else: + self.extremes = [0.] + + def generate_d_lower_upper(self, lower, upper): + if self.exponent % 2: + # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed. + # Note that for neurons with also input lower bound >=0, they will be masked later. + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), + (upper / self.step_pre).to(torch.long).reshape(-1) + ) + 1 + # Lookup the lower bound slope from the pre-computed table. + d_upper = torch.index_select(self.d_upper, 0, index).view(lower.shape) + + # Indices of neurons with lower bound <=0, whose optimal slope to upper bound the function was pre-computed. + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + (lower / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + d_lower = torch.index_select(self.d_lower, 0, index).view(upper.shape) + return d_lower, d_upper + else: + return torch.zeros_like(upper), torch.zeros_like(upper) + + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + shape = [size_spec] + list(l.shape) + alpha = torch.empty(10, *shape, device=l.device) + alpha.data[:4] = ((l + u) / 2).unsqueeze(0).expand(4, *shape) + alpha.data[4:6] = self.tp_both_lower_init[name_start].expand(2, *shape) + alpha.data[6:8] = self.tp_both_upper_init[name_start].expand(2, *shape) + alpha.data[8:10] = torch.zeros(2, *shape) + return alpha + + @torch.no_grad() + def precompute_relaxation(self, func, dfunc, x_limit = 500): + """ + This function precomputes the tangent lines that will be used as lower/upper bounds for S-shapes functions. + """ + self.x_limit = x_limit + self.step_pre = 0.01 + self.num_points_pre = int(self.x_limit / self.step_pre) + max_iter = 100 + + def check_lower(upper, d): + """Given two points upper, d (d <= upper), check if the slope at d will be less than f(upper) at upper.""" + k = dfunc(d) + # Return True if the slope is a lower bound. + return k * (upper - d) + func(d) <= func(upper) + + def check_upper(lower, d): + """Given two points lower, d (d >= lower), check if the slope at d will be greater than f(lower) at lower.""" + k = dfunc(d) + # Return True if the slope is a upper bound. + return k * (lower - d) + func(d) >= func(lower) + + # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function. + upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) + r = torch.zeros_like(upper) + # Initial guess, the tangent line is at -1. + l = -torch.ones_like(upper) + while True: + # Check if the tangent line at the guessed point is an lower bound at f(upper). + checked = check_upper(upper, l).int() + # If the initial guess is not smaller enough, then double it (-2, -4, etc). + l = checked * l + (1 - checked) * (l * 2) + if checked.sum() == l.numel(): + break + # Now we have starting point at l, its tangent line is guaranteed to be an lower bound at f(upper). + # We want to further tighten this bound by moving it closer to 0. + for _ in range(max_iter): + # Binary search. + m = (l + r) / 2 + checked = check_upper(upper, m).int() + l = checked * m + (1 - checked) * l + r = checked * r + (1 - checked) * m + # At upper, a line with slope l is guaranteed to lower bound the function. + self.d_upper = l.clone() + + # Do the same again: + # Given an lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function. + lower = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) + l = torch.zeros_like(upper) + r = torch.ones_like(upper) + while True: + checked = check_lower(lower, r).int() + r = checked * r + (1 - checked) * (r * 2) + if checked.sum() == l.numel(): + break + for _ in range(max_iter): + m = (l + r) / 2 + checked = check_lower(lower, m).int() + l = (1 - checked) * m + checked * l + r = (1 - checked) * r + checked * m + self.d_lower = r.clone() def forward(self, x, y): return torch.pow(x, y) - def bound_backward(self, last_lA, last_uA, x, y): + def bound_backward(self, last_lA, last_uA, x, y, start_node=None, + start_shape=None, **kwargs): assert not self.is_input_perturbed(1) - y = y.lower.item() - if y == int(y) and y == 2: - x_l = x.lower - x_u = torch.max(x.upper, x.lower + 1e-8) - - pow_l = self.forward(x_l, y) - pow_u = self.forward(x_u, y) - k_u = (pow_u - pow_l) / (x_u - x_l).clamp(min=1e-8) - b_u = pow_l - k_u * x_l - - k_l = torch.zeros_like(k_u) - b_l = torch.zeros_like(b_u) - x_m = (x_l + x_u) / 2 - - # TODO this only holds for y=2 - x_m = (x_u < 0) * torch.max(x_m, x_u * 2) + (x_l > 0) * torch.min(x_m, x_l * 2) - k_l = y * self.forward(x_m, y - 1) - b_l = self.forward(x_m, y) - k_l * x_m - - if last_lA is not None: - last_lA_pos, last_lA_neg = last_lA.clamp(min=0), last_lA.clamp(max=0) - lA = last_lA_pos * k_l + last_lA_neg * k_u - lb = self.get_bias(last_lA_pos, b_l) + self.get_bias(last_lA_neg, b_u) - else: - lA, lb = None, 0 - - if last_uA is not None: - last_uA_pos, last_uA_neg = last_uA.clamp(min=0), last_uA.clamp(max=0) - uA = last_uA_pos * k_u + last_uA_neg * k_l - ub = self.get_bias(last_uA_pos, b_u) + self.get_bias(last_uA_neg, b_l) - else: - uA, ub = None, 0 + self._start = start_node.name if start_node is not None else None + y = y.value + if y == int(y): + x.upper = torch.max(x.upper, x.lower + 1e-8) + self.exponent = int(y) + assert self.exponent >= 2 + if self.exponent % 2: + self.precompute_relaxation(self.act_func, self.d_act_func) + + As, lbias, ubias = super().bound_backward( + last_lA, last_uA, x, start_node, start_shape, **kwargs) + return [As[0], (None, None)], lbias, ubias + else: + raise NotImplementedError('Exponent is not supported yet') + + def bound_forward(self, dim_in, x, y): + assert y.lower == y.upper == int(y.lower) + y = y.lower + x.upper = torch.max(x.upper, x.lower + 1e-8) + self.exponent = int(y) + assert self.exponent >= 2 + if self.exponent % 2: + self.precompute_relaxation(self.act_func, self.d_act_func) + return super().bound_forward(dim_in, x) - return [(lA, uA), (None, None)], lb, ub + def bound_relax_impl(self, lb, ub): + if self.opt_stage in ['opt', 'reuse']: + if not hasattr(self, 'alpha'): + # Raise an error if alpha is not created. + self._no_bound_parameters() + ns = self._start + + self.alpha[ns].data[8:10] = torch.max( + torch.min(self.alpha[ns][8:10], ub), lb) + lb_point = self.alpha[ns][8:10] + lower_slope = self.d_act_func(lb_point) + lower_bias = self.act_func(lb_point) - lower_slope * lb_point else: - raise NotImplementedError(f'Exponent {y} is not supported yet') + lower_slope = 0 + lower_bias = 0 + + upper_slope = (self.act_func(ub) - self.act_func(lb)) / (ub - lb).clamp(min=1e-8) + upper_bias = self.act_func(ub) - ub * upper_slope + return lower_slope, lower_bias, upper_slope, upper_bias def interval_propagate(self, *v): assert not self.is_input_perturbed(1) @@ -564,91 +558,685 @@ def interval_propagate(self, *v): mask = 1 - ((v[0][0] < 0) * (v[0][1] > 0)).to(pl.dtype) return pl * mask, pu + def clamp_interim_bounds(self): + if self.exponent % 2 == 0: + self.cstr_lower = self.lower.clamp(min=0) + self.cstr_upper = self.upper.clamp(min=0) + self.cstr_interval = (self.cstr_lower, self.cstr_upper) + -class BoundReciprocal(BoundActivation): +class BoundReciprocal(BoundOptimizableActivation): + + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.splittable = True def forward(self, x): return torch.reciprocal(x) - def bound_relax(self, x): - m = (x.lower + x.upper) / 2 - kl = -1 / m.pow(2) - self.add_linear_relaxation(mask=None, type='lower', k=kl, x0=m, y0=1. / m) - ku = -1. / (x.lower * x.upper) - self.add_linear_relaxation(mask=None, type='upper', k=ku, x0=x.lower, y0=1. / x.lower) - def interval_propagate(self, *v): - h_L, h_U = v[0][0].float(), v[0][1].float() + h_L = v[0][0].to(dtype=torch.get_default_dtype()) + h_U = v[0][1].to(dtype=torch.get_default_dtype()) assert h_L.min() > 0, 'Only positive values are supported in BoundReciprocal' return torch.reciprocal(h_U), torch.reciprocal(h_L) + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) -class BoundSqrt(BoundActivation): + assert x.lower.min() > 0 - def forward(self, x): - return torch.sqrt(x) + ku = -1. / (x.lower * x.upper) + self.add_linear_relaxation(mask=None, type='upper', k=ku, x0=x.lower) - def bound_backward(self, last_lA, last_uA, x): - x_l = x.lower - x_u = torch.max(x.upper, x.lower + 1e-8) - sqrt_l = self.forward(x_l) - sqrt_u = self.forward(x_u) - k_l = (sqrt_u - sqrt_l) / (x_u - x_l).clamp(min=1e-8) - b_l = sqrt_l - k_l * x_l - - x_m = (x_l + x_u) / 2 - sqrt_m = self.forward(x_m) - k_u = -0.5 * torch.pow(x_m, -1.5) - b_u = sqrt_m - k_u * x_m - - # TODO make this part a general function - if last_lA is not None: - last_lA_pos, last_lA_neg = last_lA.clamp(min=0), last_lA.clamp(max=0) - lA = last_lA_pos * k_l + last_lA_neg * k_u - lb = self.get_bias(last_lA_pos, b_l) + self.get_bias(last_lA_neg, b_u) - else: - lA, lb = None, 0 - if last_uA is not None: - last_uA_pos, last_uA_neg = last_uA.clamp(min=0), last_uA.clamp(max=0) - uA = last_uA_pos * k_u + last_uA_neg * k_l - ub = self.get_bias(last_uA_pos, b_u) + self.get_bias(last_uA_neg, b_l) + if self.opt_stage in ['opt', 'reuse']: + self.alpha[self._start].data[:2] = torch.min(torch.max( + self.alpha[self._start].data[:2], x.lower), x.upper) + mid = self.alpha[self._start].clamp(min=0.01) else: - uA, ub = None, 0 + mid = (x.lower + x.upper) / 2 - return [(lA, uA), (None, None)], lb, ub + self.add_linear_relaxation( + mask=None, type='lower', k=-1./(mid**2), x0=mid) + def _init_opt_parameters_impl(self, size_spec, **kwargs): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + alpha = torch.empty(2, size_spec, *l.shape, device=l.device) + alpha.data[:2] = (l + u) / 2 + return alpha -class BoundSqr(BoundActivation): + +class BoundSqrt(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.use_prior_constraint = True + self.has_constraint = True def forward(self, x): - return x**2 + return torch.sqrt(x) - def bound_backward(self, last_lA, last_uA, x): - x_L, x_U = x.lower, x.upper - upper_k = x_U + x_L - upper_b = x_L**2 - upper_k * x_L - if last_uA is not None: - # Special case if we only want the upper bound with non-negative - # coefficients. - if last_uA.min() >= 0: - uA = last_uA * upper_k - ubias = self.get_bias(last_uA, upper_b) - else: - raise NotImplementedError + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + + if self.opt_stage in ['opt', 'reuse']: + self.alpha[self._start].data[:2] = torch.min(torch.max( + self.alpha[self._start].data[:2], x.lower), x.upper) + mid = self.alpha[self._start] else: - uA, ubias = None, 0 - if last_lA is not None: - if last_lA.max() <= 0: - lA = last_lA * upper_k - lbias = self.get_bias(last_lA, upper_b) - else: - raise NotImplementedError + mid = (x.lower + x.upper) / 2 + k = 0.5 / self.forward(mid) + self.add_linear_relaxation(mask=None, type='upper', k=k, x0=mid) + + sqrt_l = self.forward(x.lower) + sqrt_u = self.forward(x.upper) + k = (sqrt_u - sqrt_l) / (x.upper - x.lower).clamp(min=1e-8) + self.add_linear_relaxation(mask=None, type='lower', k=k, x0=x.lower) + + def bound_backward(self, last_lA, last_uA, x, **kwargs): + if self.use_prior_constraint and self.check_constraint_available(x): + if hasattr(x, 'cstr_interval'): + del x.cstr_interval + del x.cstr_lower + del x.cstr_upper + + x_l, x_u = self._ibp_constraint(x, delete_bounds_after_use=True) + x_u = torch.max(x_u, x_l + 1e-8) + return super().bound_backward(last_lA, last_uA, x, **kwargs) + + def clamp_interim_bounds(self): + self.cstr_lower = self.lower.clamp(min=0) + self.cstr_upper = self.upper.clamp(min=0) + self.cstr_interval = (self.cstr_lower, self.cstr_upper) + + def _init_opt_parameters_impl(self, size_spec, **kwargs): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + alpha = torch.empty(2, size_spec, *l.shape, device=l.device) + alpha.data[:2] = (l + u) / 2 + return alpha + + +class BoundSqr(BoundOptimizableActivation): + + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.splittable = True + + def forward(self, x): + return x**2 + + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + upper_k = x.lower + x.upper + # Upper bound: connect the two points (x_l, x_l^2) and (x_u, x_u^2). + # The upper bound should always be better than IBP. + self.add_linear_relaxation( + mask=None, type='upper', k=upper_k, x0=x.lower) + + if self.opt_stage in ['opt', 'reuse']: + mid = self.alpha[self._start] else: - lA, lbias = None, 0 - return [(lA, uA)], lbias, ubias + # Lower bound is a z=0 line if x_l and x_u have different signs. + # Otherwise, the lower bound is a tangent line at x_l. + # The lower bound should always be better than IBP. + # If both x_l and x_u < 0, select x_u. If both > 0, select x_l. + # If x_l < 0 and x_u > 0, we use the z=0 line as the lower bound. + mid = F.relu(x.lower) - F.relu(-x.upper) + + self.add_linear_relaxation(mask=None, type='lower', k=2*mid, x0=mid) + + def _init_opt_parameters_impl(self, size_spec, **kwargs): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + alpha = torch.empty(2, size_spec, *l.shape, device=l.device) + alpha.data[:2] = F.relu(l) - F.relu(-u) + return alpha def interval_propagate(self, *v): h_L, h_U = v[0][0], v[0][1] lower = ((h_U < 0) * (h_U**2) + (h_L > 0) * (h_L**2)) upper = torch.max(h_L**2, h_U**2) return lower, upper + + +class BoundMinMax(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.options = options + self.requires_input_bounds = [0, 1] + self.op = None + + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" + l = self.inputs[0].lower + # Alpha dimension is (8, output_shape, batch, *shape) for Tanh. + return torch.ones_like(l).unsqueeze(0).repeat(2, *[1] * l.ndim) + + def clip_alpha(self): + for v in self.alpha.values(): + v.data = torch.clamp(v.data, 0., 1.) + + def forward(self, x, y): + if self.op == 'max': + return torch.max(x, y) + elif self.op == 'min': + return torch.min(x, y) + else: + raise NotImplementedError + + def _backward_relaxation(self, last_lA, last_uA, x, y, start_node): + lb_x = x.lower + ub_x = x.upper + lb_y = y.lower + ub_y = y.upper + + ub_x = torch.max(ub_x, lb_x + 1e-8) + ub_y = torch.max(ub_y, lb_y + 1e-8) + + if self.opt_stage in ['opt', 'reuse']: + selected_alpha = self.alpha[start_node.name] + alpha_u = selected_alpha[0].squeeze(0) + alpha_l = selected_alpha[1].squeeze(0) + else: + alpha_u = alpha_l = 1 + + # Generate masks for stable and unstable neurons + # Neurons are stable when x, y bounds fall in z=x or z=y plane + x_mask = (lb_x >= ub_y).requires_grad_(False).to(lb_x.dtype) + y_mask = (lb_y >= ub_x).requires_grad_(False).to(lb_y.dtype) + no_mask = (1. - x_mask) * (1. - y_mask) + + # Calculate dx, dy, b coefficients according to https://www.overleaf.com/read/dbyyfpjhhwbk + if self.op == 'max': + upper_dx = x_mask + no_mask * ( + (ub_y - ub_x) / (alpha_u * (lb_x - ub_x))) + upper_dy = y_mask + no_mask * ( + (alpha_u - 1) * (ub_y - ub_x)) / (alpha_u * (ub_y - lb_y)) + upper_b = no_mask * ( + ub_x - (ub_x * (ub_y - ub_x)) / (alpha_u * (lb_x - ub_x)) + - ((alpha_u - 1) * (ub_y - ub_x) * lb_y) / ( + alpha_u * (ub_y - lb_y))) + lower_dx = x_mask + no_mask * alpha_l + lower_dy = y_mask + no_mask * (1 - alpha_l) + lower_b = None + elif self.op == 'min': + lower_dx = y_mask + no_mask * ( + (lb_x - lb_y) / (alpha_u * (lb_x - ub_x))) + lower_dy = x_mask + no_mask * ( + (alpha_u - 1) * (lb_x - lb_y)) / (alpha_u * (ub_y - lb_y)) + lower_b = no_mask * ( + lb_y - (ub_x * (lb_x - lb_y)) / (alpha_u * (lb_x - ub_x)) + - ((alpha_u - 1) * (lb_x - lb_y) * lb_y) / ( + alpha_u * (ub_y - lb_y))) + upper_dx = y_mask + no_mask * alpha_l + upper_dy = x_mask + no_mask * (1 - alpha_l) + upper_b = None + else: + raise NotImplementedError + + upper_dx = upper_dx.unsqueeze(0) + upper_dy = upper_dy.unsqueeze(0) + lower_dx = lower_dx.unsqueeze(0) + lower_dy = lower_dy.unsqueeze(0) + if upper_b is not None: + upper_b = upper_b.unsqueeze(0) + if lower_b is not None: + lower_b = lower_b.unsqueeze(0) + + return upper_dx, upper_dy, upper_b, lower_dx, lower_dy, lower_b + + def bound_backward(self, last_lA, last_uA, x=None, y=None, start_shape=None, + start_node=None, **kwargs): + # Get element-wise CROWN linear relaxations. + upper_dx, upper_dy, upper_b, lower_dx, lower_dy, lower_b = \ + self._backward_relaxation(last_lA, last_uA, x, y, start_node) + + # Choose upper or lower bounds based on the sign of last_A + def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): + if last_A is None: + return None, 0 + # Obtain the new linear relaxation coefficients based on the signs in last_A. + _A, _bias = multiply_by_A_signs(last_A, d_pos, d_neg, b_pos, b_neg) + if isinstance(last_A, Patches): + # Save the patch size, which will be used in init_slope() to determine the number of optimizable parameters. + A_prod = _A.patches + if start_node is not None: + # Regular patches. + self.patch_size[start_node.name] = A_prod.size() + return _A, _bias + + # In patches mode we might need an unfold. + # lower_dx, lower_dy, upper_dx, upper_dy, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None + upper_dx = maybe_unfold_patches(upper_dx, last_lA if last_lA is not None else last_uA) + upper_dy = maybe_unfold_patches(upper_dy, last_lA if last_lA is not None else last_uA) + lower_dx = maybe_unfold_patches(lower_dx, last_lA if last_lA is not None else last_uA) + lower_dy = maybe_unfold_patches(lower_dy, last_lA if last_lA is not None else last_uA) + upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA) + lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA) + + uAx, ubias = _bound_oneside(last_uA, upper_dx, lower_dx, upper_b, lower_b) + uAy, ubias = _bound_oneside(last_uA, upper_dy, lower_dy, upper_b, lower_b) + lAx, lbias = _bound_oneside(last_lA, lower_dx, upper_dx, lower_b, upper_b) + lAy, lbias = _bound_oneside(last_lA, lower_dy, upper_dy, lower_b, upper_b) + + return [(lAx, uAx), (lAy, uAy)], lbias, ubias + + def interval_propagate(self, *v): + h_Lx, h_Ux = v[0][0], v[0][1] + h_Ly, h_Uy = v[1][0], v[1][1] + return self.forward(h_Lx, h_Ly), self.forward(h_Ux, h_Uy) + +class BoundMax(BoundMinMax): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.op = 'max' + +class BoundMin(BoundMinMax): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.op = 'min' + + +class BoundGELU(BoundOptimizableNonLinear): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.ibp_intermediate = False + self.use_precompute = True + self.act_func = F.gelu + def d_act_func(x): + return 0.5 * (1 + torch.erf(x / np.sqrt(2))) + x * torch.exp(-0.5 * x ** 2) / np.sqrt(2 * torch.pi) + self.d_act_func = d_act_func + def d2_act_func(x): + return 2 * torch.exp(-0.5 * x ** 2) / np.sqrt(2 * torch.pi) - x ** 2 * torch.exp(-0.5 * x ** 2) / np.sqrt(2 * torch.pi) + self.d2_act_func = d2_act_func + self.precompute_relaxation('gelu', self.act_func, self.d_act_func) + + def _init_masks(self, x): + lower = x.lower + upper = x.upper + self.mask_left_pos = torch.logical_and(lower >= -np.sqrt(2), upper <= 0) + self.mask_left_neg = upper <= -np.sqrt(2) + self.mask_left = torch.logical_xor(upper <= 0, + torch.logical_or(self.mask_left_pos, self.mask_left_neg)) + + self.mask_right_pos = lower >= np.sqrt(2) + self.mask_right_neg = torch.logical_and(upper <= np.sqrt(2), lower >= 0) + self.mask_right = torch.logical_xor(lower >= 0, + torch.logical_or(self.mask_right_pos, self.mask_right_neg)) + + self.mask_2 = torch.logical_and(torch.logical_and(upper > 0, upper <= np.sqrt(2)), + torch.logical_and(lower < 0, lower >= -np.sqrt(2))) + self.mask_left_3 = torch.logical_and(lower < -np.sqrt(2), torch.logical_and( + upper > 0, upper <= np.sqrt(2))) + self.mask_right_3 = torch.logical_and(upper > np.sqrt(2), torch.logical_and( + lower < 0, lower >= -np.sqrt(2))) + self.mask_4 = torch.logical_and(lower < -np.sqrt(2), upper > np.sqrt(2)) + self.mask_both = torch.logical_or(self.mask_2, torch.logical_or(self.mask_4, + torch.logical_or(self.mask_left_3, self.mask_right_3))) + + @torch.no_grad() + def precompute_relaxation(self, name, func, dfunc, x_limit=1000): + """ + This function precomputes the tangent lines that will be used as + lower/upper bounds for S-shapes functions. + """ + self.x_limit = x_limit + self.step_pre = 0.01 + self.num_points_pre = int(self.x_limit / self.step_pre) + max_iter = 100 + + logger.debug('Precomputing relaxation for %s (pre-activation limit: %f)', + name, x_limit) + + def check_lower(upper, d): + """Given two points upper, d (d <= upper), check if the slope at d will be less than f(upper) at upper.""" + k = dfunc(d) + # Return True if the slope is a lower bound. + return k * (upper - d) + func(d) <= func(upper) + + def check_upper(lower, d): + """Given two points lower, d (d >= lower), check if the slope at d will be greater than f(lower) at lower.""" + k = dfunc(d) + # Return True if the slope is a upper bound. + return k * (lower - d) + func(d) >= func(lower) + + # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function. + upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) + np.sqrt(2) + r = torch.ones_like(upper) + # Initial guess, the tangent line is at -1. + l = -torch.ones_like(upper) + while True: + # Check if the tangent line at the guessed point is an lower bound at f(upper). + checked = check_lower(upper, l).int() + # If the initial guess is not smaller enough, then double it (-2, -4, etc). + l = checked * l + (1 - checked) * (l * 2) + if checked.sum() == l.numel(): + break + # Now we have starting point at l, its tangent line is guaranteed to be an lower bound at f(upper). + # We want to further tighten this bound by moving it closer to 0. + for _ in range(max_iter): + # Binary search. + m = (l + r) / 2 + checked = check_lower(upper, m).int() + l = checked * m + (1 - checked) * l + r = checked * r + (1 - checked) * m + # At upper, a line with slope l is guaranteed to lower bound the function. + self.d_lower_right = l.clone() + + # Do the same again: + # Given an lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function. + lower = (-self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) + np.sqrt(2)).clamp(min=0.01) + l = torch.zeros_like(upper) + np.sqrt(2) + r = torch.zeros_like(upper) + x_limit + while True: + checked = check_upper(lower, r).int() + r = checked * r + (1 - checked) * (r * 2) + if checked.sum() == l.numel(): + break + for _ in range(max_iter): + m = (l + r) / 2 + checked = check_upper(lower, m).int() + l = (1 - checked) * m + checked * l + r = (1 - checked) * r + checked * m + self.d_upper_right = r.clone() + + upper = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) - np.sqrt(2) + r = torch.zeros_like(upper) - 0.7517916 + # Initial guess, the tangent line is at -1. + l = torch.zeros_like(upper) - np.sqrt(2) + while True: + checked = check_lower(upper, r).int() + r = checked * r + (1 - checked) * (r * 2) + if checked.sum() == l.numel(): + break + # Now we have starting point at l, its tangent line is guaranteed to be an lower bound at f(upper). + # We want to further tighten this bound by moving it closer to 0. + for _ in range(max_iter): + # Binary search. + m = (l + r) / 2 + checked = check_lower(upper, m).int() + l = (1 - checked) * m + checked * l + r = (1 - checked) * r + checked * m + # At upper, a line with slope l is guaranteed to lower bound the function. + self.d_lower_left = r.clone() + + # Do the same again: + # Given an lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function. + lower = (self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) - np.sqrt(2)).clamp(max=0) + l = torch.zeros_like(upper) - x_limit + r = torch.zeros_like(upper) - np.sqrt(2) + while True: + checked = check_upper(lower, l).int() + l = checked * l + (1 - checked) * (l * 2) + if checked.sum() == l.numel(): + break + for _ in range(max_iter): + m = (l + r) / 2 + checked = check_upper(lower, m).int() + l = (1 - checked) * m + checked * l + r = (1 - checked) * r + checked * m + self.d_upper_left = r.clone() + + logger.debug('Done') + + def opt_init(self): + super().opt_init() + self.tp_right_lower_init = {} + self.tp_right_upper_init = {} + self.tp_left_lower_init = {} + self.tp_left_upper_init = {} + self.tp_both_lower_init = {} + + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + shape = [size_spec] + list(l.shape) + alpha = torch.empty(14, *shape, device=l.device) + alpha.data[:4] = ((l + u) / 2).unsqueeze(0).expand(4, *shape) + alpha.data[4:6] = self.tp_right_lower_init[name_start].expand(2, *shape) + alpha.data[6:8] = self.tp_right_upper_init[name_start].expand(2, *shape) + alpha.data[8:10] = self.tp_left_lower_init[name_start].expand(2, *shape) + alpha.data[10:12] = self.tp_left_upper_init[name_start].expand(2, *shape) + alpha.data[12:14] = self.tp_both_lower_init[name_start].expand(2, *shape) + return alpha + + def forward(self, x): + return F.gelu(x) + + def bound_relax_impl(self, x, func, dfunc): + lower, upper = x.lower, x.upper + y_l, y_u = func(lower), func(upper) + # k_direct is the slope of the line directly connect (lower, func(lower)), (upper, func(upper)). + k_direct = k = (y_u - y_l) / (upper - lower).clamp(min=1e-8) + + # Fixed bounds that cannot be optimized. self.mask_neg are the masks for neurons with upper bound <= 0. + # Upper bound for the case of input lower bound <= 0, is always the direct line. + self.add_linear_relaxation( + mask=torch.logical_or(torch.logical_or(self.mask_left_pos, + self.mask_right_neg), self.mask_both), type='upper', k=k_direct, x0=lower, y0=y_l) + # Lower bound for the case of input upper bound >= 0, is always the direct line. + self.add_linear_relaxation( + mask=torch.logical_or(self.mask_left_neg, + self.mask_right_pos), type='lower', k=k_direct, x0=lower, y0=y_l) + + # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed. + # Note that for neurons with also input lower bound >=0, they will be masked later. + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), + ((upper - np.sqrt(2)) / self.step_pre).to(torch.long).reshape(-1) + ) + 1 + if index.max() >= self.d_lower_right.numel(): + warnings.warn(f'Pre-activation bounds are too loose for {self}') + # Lookup the lower bound slope from the pre-computed table. + d_lower_right = torch.where( + (index < self.d_lower_right.numel()).view(lower.shape), + torch.index_select( + self.d_lower_right, 0, index.clamp(max=self.d_lower_right.numel() - 1) + ).view(lower.shape), + lower, + # If the pre-activation bounds are too loose, just use IBP. + # torch.ones_like(index).to(lower) * (-100.) + ) + else: + # Lookup the lower bound slope from the pre-computed table. + d_lower_right = torch.index_select( + self.d_lower_right, 0, index).view(lower.shape) + + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + ((lower + np.sqrt(2)) / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + if index.max() >= self.d_lower_left.numel(): + warnings.warn(f'Pre-activation bounds are too loose for {self}') + # Lookup the lower bound slope from the pre-computed table. + d_lower_left = torch.where( + (index < self.d_lower_left.numel()).view(upper.shape), + torch.index_select( + self.d_lower_left, 0, index.clamp(max=self.d_lower_left.numel() - 1) + ).view(lower.shape), + upper, + ).view(lower.shape) + else: + # Lookup the lower bound slope from the pre-computed table. + d_lower_left = torch.index_select( + self.d_lower_left, 0, index).view(lower.shape) + + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + ((lower - np.sqrt(2)) / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + if index.max() >= self.d_upper_right.numel(): + warnings.warn(f'Pre-activation bounds are too loose for {self}') + # Lookup the lower bound slope from the pre-computed table. + d_upper_right = torch.where( + (index < self.d_upper_right.numel()).view(upper.shape), + torch.index_select( + self.d_upper_right, 0, index.clamp(max=self.d_upper_right.numel() - 1) + ).view(upper.shape), + upper, + ) + else: + d_upper_right = torch.index_select( + self.d_upper_right, 0, index).view(upper.shape) + + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=lower.device), + ((upper + np.sqrt(2)) / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + if index.max() >= self.d_upper_left.numel(): + warnings.warn(f'Pre-activation bounds are too loose for {self}') + # Lookup the lower bound slope from the pre-computed table. + d_upper_left = torch.where( + (index < self.d_upper_left.numel()).view(upper.shape), + torch.index_select( + self.d_upper_left, 0, index.clamp(max=self.d_upper_left.numel() - 1) + ).view(upper.shape), + upper, + ) + else: + d_upper_left = torch.index_select( + self.d_upper_left, 0, index).view(upper.shape) + + if self.opt_stage in ['opt', 'reuse']: + if not hasattr(self, 'alpha'): + # Raise an error if alpha is not created. + self._no_bound_parameters() + ns = self._start + + # Clipping is done here rather than after `opt.step()` call + # because it depends on pre-activation bounds + self.alpha[ns].data[0:2] = torch.max( + torch.min(self.alpha[ns][0:2], upper), lower) + self.alpha[ns].data[2:4] = torch.max( + torch.min(self.alpha[ns][2:4], upper), lower) + self.alpha[ns].data[4:6] = torch.max( + torch.min(self.alpha[ns][4:6], d_lower_right), lower) + self.alpha[ns].data[6:8] = torch.max( + self.alpha[ns][6:8], d_upper_right) + self.alpha[ns].data[8:10] = torch.min( + torch.max(self.alpha[ns][8:10], d_lower_left), upper) + self.alpha[ns].data[10:12] = torch.min( + self.alpha[ns][10:12], d_upper_left) + self.alpha[ns].data[12:14] = torch.min( + torch.max(self.alpha[ns][12:14], d_lower_left), d_lower_right) + + # shape [2, out_c, n, c, h, w]. + tp_pos = self.alpha[ns][0:2] # For upper bound relaxation + tp_neg = self.alpha[ns][2:4] # For lower bound relaxation + tp_right_lower = self.alpha[ns][4:6] + tp_right_upper = self.alpha[ns][6:8] + tp_left_lower = self.alpha[ns][8:10] + tp_left_upper = self.alpha[ns][10:12] + tp_both_lower = self.alpha[ns][12:14] + + # No need to use tangent line, when the tangent point is at the left + # side of the preactivation lower bound. Simply connect the two sides. + mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(lower)) + self.add_linear_relaxation( + mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_or(self.mask_right_3, + torch.logical_xor(self.mask_right, mask_direct)), type='lower', + k=dfunc(tp_right_lower), x0=tp_right_lower) + mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(upper)) + self.add_linear_relaxation( + mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_or(self.mask_left_3, + torch.logical_xor(self.mask_left, mask_direct)), type='lower', + k=dfunc(tp_left_lower), x0=tp_left_lower) + + mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(upper)) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_right, mask_direct), type='upper', + k=dfunc(tp_right_upper), x0=tp_right_upper) + mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(lower)) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_left, mask_direct), type='upper', + k=dfunc(tp_left_upper), x0=tp_left_upper) + + self.add_linear_relaxation( + mask=self.mask_4, type='lower', k=dfunc(tp_both_lower), x0=tp_both_lower) + self.add_linear_relaxation( + mask=torch.logical_or(torch.logical_or(self.mask_left_pos, self.mask_right_neg), + self.mask_2), type='lower', k=dfunc(tp_neg), x0=tp_neg) + self.add_linear_relaxation( + mask=torch.logical_or(self.mask_right_pos, + self.mask_left_neg), type='upper', k=dfunc(tp_pos), x0=tp_pos) + else: + if self.opt_stage == 'init': + # Initialize optimizable slope. + tp_right_lower_init = d_lower_right.detach() + tp_right_upper_init = d_upper_right.detach() + tp_left_lower_init = d_lower_left.detach() + tp_left_upper_init = d_upper_left.detach() + tp_both_lower_init = d_lower_right.detach() + + ns = self._start + self.tp_right_lower_init[ns] = tp_right_lower_init + self.tp_right_upper_init[ns] = tp_right_upper_init + self.tp_left_lower_init[ns] = tp_left_lower_init + self.tp_left_upper_init[ns] = tp_left_upper_init + self.tp_both_lower_init[ns] = tp_both_lower_init + + # Not optimized (vanilla CROWN bound). + # Use the middle point slope as the lower/upper bound. Not optimized. + m = (lower + upper) / 2 + y_m = func(m) + k = dfunc(m) + # Lower bound is the middle point slope for the case input upper bound <= 0. + # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation(mask=torch.logical_or(torch.logical_or(self.mask_left_pos, self.mask_right_neg), + self.mask_2), type='lower', k=k, x0=m, y0=y_m) + # Upper bound is the middle point slope for the case input lower bound >= 0. + # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation(mask=torch.logical_or(self.mask_right_pos, + self.mask_left_neg), type='upper', k=k, x0=m, y0=y_m) + + # Now handle the case where input lower bound <=0 and upper bound >= 0. + # A tangent line starting at d_lower is guaranteed to be a lower bound given the input upper bound. + mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(lower)) + self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + # Otherwise we do not use the direct line, we use the d_lower slope. + self.add_linear_relaxation( + mask=torch.logical_or(torch.logical_or(self.mask_right_3, self.mask_4), + torch.logical_xor(self.mask_right, mask_direct)), type='lower', + k=dfunc(d_lower_right), x0=d_lower_right) + mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(upper)) + self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_or(self.mask_left_3, + torch.logical_xor(self.mask_left, mask_direct)), type='lower', + k=dfunc(d_lower_left), x0=d_lower_left) + + mask_direct = torch.logical_and(self.mask_right, k_direct < dfunc(upper)) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_right, mask_direct), type='upper', + k=dfunc(d_upper_right), x0=d_upper_right) + mask_direct = torch.logical_and(self.mask_left, k_direct > dfunc(lower)) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_left, mask_direct), type='upper', + k=dfunc(d_upper_left), x0=d_upper_left) + + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + self.bound_relax_impl(x, self.act_func, self.d_act_func) + + def interval_propagate(self, *v): + pl, pu = self.forward(v[0][0]), self.forward(v[0][1]) + pl, pu = torch.min(pl, pu), torch.max(pl, pu) + min_global = self.forward(torch.tensor(-0.7517916)) + pl, pu = torch.min(min_global, torch.min(pl, pu)), torch.max(pl, pu) + return pl, pu \ No newline at end of file diff --git a/auto_LiRPA/operators/normalization.py b/auto_LiRPA/operators/normalization.py index ec6abe2..1597326 100644 --- a/auto_LiRPA/operators/normalization.py +++ b/auto_LiRPA/operators/normalization.py @@ -1,8 +1,12 @@ """ Normalization operators""" import copy + +import torch + from .base import * from .solver_utils import grb + class BoundBatchNormalization(Bound): def __init__(self, attr, inputs, output_index, options, training): super().__init__(attr, inputs, output_index, options) @@ -53,7 +57,50 @@ def forward(self, x, w, b, m, v): result = w.view(*shape) * x + b.view(*shape) return result - def bound_backward(self, last_lA, last_uA, *x): + def bound_forward(self, dim_in, *x): + inp = x[0] + assert (x[1].lower == x[1].upper).all(), "unsupported forward bound with perturbed mean" + assert (x[2].lower == x[2].upper).all(), "unsupported forward bound with perturbed var" + weight, bias = x[1].lower, x[2].lower + if not self.training: + assert (x[3].lower == x[3].upper).all(), "unsupported forward bound with perturbed mean" + assert (x[4].lower == x[4].upper).all(), "unsupported forward bound with perturbed var" + self.current_mean = x[3].lower + self.current_var = x[4].lower + self._check_unused_mean_or_var() + if not self.use_affine: + weight = torch.ones_like(weight) + bias = torch.zeros_like(bias) + + + tmp_bias = bias - self.current_mean / torch.sqrt(self.current_var + self.eps) * weight + tmp_weight = weight / torch.sqrt(self.current_var + self.eps) + + # for debug: this checking is passed, i.e., we derived the forward bound + # from the following correct computation procedure + # tmp_x = ((x[0].lb + x[0].ub) / 2.).detach() + # expect_output = self(tmp_x, *[_.lower for _ in x[1:]]) + # tmp_weight = tmp_weight.view(*((1, -1) + (1,) * (tmp_x.ndim - 2))) + # tmp_bias = tmp_bias.view(*((1, -1) + (1,) * (tmp_x.ndim - 2))) + # computed_output = tmp_weight * tmp_x + tmp_bias + # assert torch.allclose(expect_output, computed_output, 1e-5, 1e-5) + + tmp_weight = tmp_weight.view(*((1, 1, -1) + (1,) * (inp.lw.ndim - 3))) + new_lw = torch.clamp(tmp_weight, min=0.) * inp.lw + torch.clamp(tmp_weight, max=0.) * inp.uw + new_uw = torch.clamp(tmp_weight, min=0.) * inp.uw + torch.clamp(tmp_weight, max=0.) * inp.lw + + tmp_weight = tmp_weight.view(*((1, -1) + (1,) * (inp.lb.ndim - 2))) + tmp_bias = tmp_bias.view(*((1, -1) + (1,) * (inp.lb.ndim - 2))) + new_lb = torch.clamp(tmp_weight, min=0.) * inp.lb + torch.clamp(tmp_weight, max=0.) * inp.ub + tmp_bias + new_ub = torch.clamp(tmp_weight, min=0.) * inp.ub + torch.clamp(tmp_weight, max=0.) * inp.lb + tmp_bias + + return LinearBound( + lw = new_lw, + lb = new_lb, + uw = new_uw, + ub = new_ub) + + def bound_backward(self, last_lA, last_uA, *x, **kwargs): assert not self.is_input_perturbed(1) and not self.is_input_perturbed(2), \ 'Weight perturbation is not supported for BoundBatchNormalization' @@ -88,7 +135,7 @@ def _bound_oneside(last_A): # tmp_weight has shape (c,), it will be applied on the (c,) dimension. patches = patches * tmp_weight.view(*([1] * (patches.ndim - 3)), -1, 1, 1) # Match with sparse or non-sparse patches. - next_A = Patches(patches, last_A.stride, last_A.padding, last_A.shape, identity=0, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) + next_A = last_A.create_similar(patches) # bias to size (c,), need expansion before unfold. bias = tmp_bias.view(-1,1,1).expand(self.input_shape[1:]).unsqueeze(0) @@ -116,7 +163,7 @@ def _bound_oneside(last_A): patches = patches[last_A.unstable_idx[0], :, last_A.unstable_idx[1], last_A.unstable_idx[2]] # Expand the batch dimension. patches = patches.expand(-1, last_A.shape[1], *([-1] * (patches.ndim - 2))) - next_A = Patches(patches, 1, 0, last_A.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape) + next_A = last_A.create_similar(patches, stride=1, padding=0, identity=0) if last_A.unstable_idx is not None: # Need to expand the bias and choose the selected ones. bias = tmp_bias.view(-1,1,1,1).expand(-1, 1, last_A.output_shape[2], last_A.output_shape[3]) @@ -168,7 +215,7 @@ def interval_propagate(self, *v): # interval_propagate() of the Linear layer may encounter input with different norms. norm, eps = Interval.get_perturbation(v[0])[:2] - if norm == np.inf: + if norm == torch.inf: center = tmp_weight.view(*shape) * mid + tmp_bias.view(*shape) deviation = tmp_weight_abs.view(*shape) * diff elif norm > 0: @@ -229,5 +276,7 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") new_layer_gurobi_vars.append(out_chan_vars) self.solver_vars = new_layer_gurobi_vars - # self.solver_constrs = new_layer_gurobi_constrs model.update() + + def update_requires_input_bounds(self): + self._check_weight_perturbation() diff --git a/auto_LiRPA/operators/pooling.py b/auto_LiRPA/operators/pooling.py index 4594067..6b41249 100644 --- a/auto_LiRPA/operators/pooling.py +++ b/auto_LiRPA/operators/pooling.py @@ -7,9 +7,8 @@ class BoundMaxPool(BoundOptimizableActivation): - #FIXME clean up needed - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) assert ('pads' not in attr) or (attr['pads'][0] == attr['pads'][2]) assert ('pads' not in attr) or (attr['pads'][1] == attr['pads'][3]) @@ -20,12 +19,12 @@ def __init__(self, attr, inputs, output_index, options): self.padding = [attr['pads'][0], attr['pads'][1]] self.ceil_mode = False self.use_default_ibp = True - self.alpha = None + self.alpha = {} self.init = {} - self.alpha_batch_dim = 2 def forward(self, x): - output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + output, _ = F.max_pool2d(x, self.kernel_size, self.stride, self.padding, + return_indices=True, ceil_mode=self.ceil_mode) return output def project_simplex(self, patches): @@ -33,33 +32,35 @@ def project_simplex(self, patches): sorted, _ = torch.sort(sorted, -1, descending=True) rho_sum = torch.cumsum(sorted, -1) rho_value = 1 - rho_sum - rho_value = (sorted + rho_value/torch.tensor(range(1, sorted.size(-1)+1), dtype=torch.float, device=sorted.device)) > 0 + rho_value = (sorted + rho_value/torch.tensor( + range(1, sorted.size(-1)+1), dtype=torch.float, + device=sorted.device)) > 0 _, rho_index = torch.max(torch.cumsum(rho_value, -1), -1) rho_sum = torch.gather(rho_sum, -1, rho_index.unsqueeze(-1)).squeeze(-1) lbd = 1/(rho_index+1)* (1-rho_sum) return torch.clamp(patches + lbd.unsqueeze(-1).unsqueeze(-1), min=0) - def init_opt_parameters(self, start_nodes): - self.alpha = OrderedDict() + def _init_opt_parameters_impl(self, size_spec, name_start): + if name_start == '_forward': + warnings.warn("MaxPool's optimization is not supported for forward mode") + return None ref = self.inputs[0].lower # a reference variable for getting the shape - for ns, size_s, unstable_idx in start_nodes: - if ns == '_forward': - warnings.warn("MaxPool's optimization is not supported for forward mode") - continue - self.alpha[ns] = torch.empty( - [1, size_s, self.input_shape[0], self.input_shape[1], - self.output_shape[-2], self.output_shape[-1], - self.kernel_size[0], self.kernel_size[1]], - dtype=torch.float, device=ref.device, requires_grad=True) - self.init[ns] = False + alpha = torch.empty( + [1, size_spec, self.input_shape[0], self.input_shape[1], + self.output_shape[-2], self.output_shape[-1], + self.kernel_size[0], self.kernel_size[1]], + dtype=torch.float, device=ref.device, requires_grad=True) + self.init[name_start] = False + return alpha @staticmethod @torch.jit.script def jit_mutiply(Apos, Aneg, pos, neg): return pos.contiguous() * Apos + neg.contiguous() * Aneg - def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, unstable_idx=None): + def bound_backward(self, last_lA, last_uA, x, start_node=None, + unstable_idx=None, **kwargs): # self.padding is a tuple of two elements: (height dimension padding, width dimension padding). paddings = tuple((self.padding[0], self.padding[0], self.padding[1], self.padding[1])) @@ -80,31 +81,45 @@ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, # Find the maxpool neuron whose input bounds satisfy l_i > max_j u_j for all j != i. In this case, the maxpool neuron is linear, and we can set upper_d = lower_d = 1. # We first find which indices has the largest lower bound. - max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + max_lower, max_lower_index = F.max_pool2d( + x.lower, self.kernel_size, self.stride, self.padding, + return_indices=True, ceil_mode=self.ceil_mode) # Set the upper bound of the i-th input to -inf so it will not be selected as the max. if paddings == (0,0,0,0): delete_upper = torch.scatter( torch.flatten(x.upper, -2), -1, - torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape) + torch.flatten(max_lower_index, -2), -torch.inf).view(upper_d.shape) else: - delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape) + delete_upper = torch.scatter( + torch.flatten(F.pad(x.upper, paddings), -2), -1, + torch.flatten(max_lower_index, -2), + -torch.inf).view(upper_d.shape) # Find the the max upper bound over the remaining ones. - max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode) + max_upper, _ = F.max_pool2d( + delete_upper, self.kernel_size, self.stride, 0, + return_indices=True, ceil_mode=self.ceil_mode) # The upper bound slope for maxpool is either 1 on input satisfies l_i > max_j u_j (linear), or 0 everywhere. Upper bound is not optimized. values = torch.zeros_like(max_lower) values[max_lower >= max_upper] = 1.0 - upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape) + upper_d = torch.scatter( + torch.flatten(upper_d, -2), -1, + torch.flatten(max_lower_index, -2), + torch.flatten(values, -2)).view(upper_d.shape) if self.opt_stage == 'opt': if unstable_idx is not None and self.alpha[start_node.name].size(1) != 1: - if unstable_idx.ndim == 1: + if isinstance(unstable_idx, tuple): + raise NotImplementedError('Please use --conv_mode matrix') + elif unstable_idx.ndim == 1: # Only unstable neurons of the start_node neurons are used. - alpha = self.non_deter_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) + alpha = self.non_deter_index_select( + self.alpha[start_node.name], index=unstable_idx, dim=1) elif unstable_idx.ndim == 2: # Each element in the batch selects different neurons. - alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) + alpha = batched_index_select( + self.alpha[start_node.name], index=unstable_idx, dim=1) else: raise ValueError else: @@ -113,37 +128,51 @@ def bound_backward(self, last_lA, last_uA, x, start_node=None, start_shape=None, if not self.init[start_node.name]: lower_d = torch.zeros((shape), device=x.device) # [batch, C, H, W] - lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) + lower_d = torch.scatter( + torch.flatten(lower_d, -2), -1, + torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) # shape [batch, C*k*k, L] - lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride) + lower_d_unfold = F.unfold( + lower_d, self.kernel_size, 1, stride=self.stride) # [batch, C, k, k, out_H, out_W] - alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1]) + alpha_data = lower_d_unfold.view( + lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], + self.kernel_size[1], self.output_shape[-2], self.output_shape[-1]) # [batch, C, out_H, out_W, k, k] alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach()) self.init[start_node.name] = True # In optimization mode, we use the same lower_d once builded. if self.padding[0] > 0 or self.padding[1] > 0: - lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]] + lower_d = lower_d[...,self.padding[0]:-self.padding[0], + self.padding[1]:-self.padding[1]] # The lower bound coefficients must be positive and projected to an unit simplex. alpha.data = self.project_simplex(alpha.data).clone().detach() # TODO: don't do this, never re-assign the .data property. Use copy_ instead. # permute the last 6 dimensions of alpha to [batch, C, k, k, out_H, out_W], which prepares for the unfold operation. alpha = alpha.permute((0,1,2,3,6,7,4,5)) alpha_shape = alpha.shape - alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1])) - lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride) - lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:]) + alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], + -1, alpha_shape[-2]*alpha_shape[-1])) + lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, + self.padding, self.stride) + lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], + alpha_shape[2], *lower_d.shape[1:]) lower_d = lower_d.squeeze(0) else: lower_d = torch.zeros((shape), device=x.device) # Not optimizable bounds. We simply set \hat{z} >= z_i where i is the input element with largest lower bound. - lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) + lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, + torch.flatten(max_lower_index, -2), + 1.0).view(upper_d.shape) if self.padding[0] > 0 or self.padding[1] > 0: - lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]] + lower_d = lower_d[...,self.padding[0]:-self.padding[0], + self.padding[1]:-self.padding[1]] # For the upper bound, we set the bias term to concrete upper bounds for maxpool neurons that are not linear. - max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, + self.padding, return_indices=True, + ceil_mode=self.ceil_mode) upper_b[max_upper > max_lower] = max_upper_[max_upper > max_lower] def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): @@ -169,40 +198,55 @@ def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): d_pos = F.pad(d_pos, padding) d_neg = F.pad(d_neg, padding) - pos_A = F.interpolate(pos_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size) + pos_A = F.interpolate( + pos_A.view(shape[0] * shape[1], *shape[2:]), + scale_factor=self.kernel_size) if d_pos.shape[-2] > pos_A.shape[-2] or d_pos.shape[-1] > pos_A.shape[-1]: if not (d_pos.shape[-2] > pos_A.shape[-2] and d_pos.shape[-1] > pos_A.shape[-1]): - raise NotImplementedError("Asymmetric padding of maxpool not implemented.") - pos_A = F.pad(pos_A, (0, d_pos.shape[-2] - pos_A.shape[-2], 0, d_pos.shape[-1] - pos_A.shape[-1])) + raise NotImplementedError( + "Asymmetric padding of maxpool not implemented.") + pos_A = F.pad(pos_A, (0, d_pos.shape[-2] - pos_A.shape[-2], + 0, d_pos.shape[-1] - pos_A.shape[-1])) else: - d_pos = F.pad(d_pos, (0, pos_A.shape[-2] - d_pos.shape[-2], 0, pos_A.shape[-1] - d_pos.shape[-1])) + d_pos = F.pad(d_pos, (0, pos_A.shape[-2] - d_pos.shape[-2], + 0, pos_A.shape[-1] - d_pos.shape[-1])) pos_A = pos_A.view(shape[0], shape[1], *pos_A.shape[1:]) - neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]), scale_factor=self.kernel_size) + neg_A = F.interpolate(neg_A.view(shape[0] * shape[1], *shape[2:]), + scale_factor=self.kernel_size) if d_neg.shape[-2] > neg_A.shape[-2] or d_neg.shape[-1] > neg_A.shape[-1]: if not (d_neg.shape[-2] > neg_A.shape[-2] and d_neg.shape[-1] > neg_A.shape[-1]): raise NotImplementedError("Asymmetric padding of maxpool not implemented.") - neg_A = F.pad(neg_A, (0, d_neg.shape[-2] - neg_A.shape[-2], 0, d_neg.shape[-1] - neg_A.shape[-1])) + neg_A = F.pad(neg_A, (0, d_neg.shape[-2] - neg_A.shape[-2], + 0, d_neg.shape[-1] - neg_A.shape[-1])) else: - d_neg = F.pad(d_neg, (0, neg_A.shape[-2] - d_neg.shape[-2], 0, neg_A.shape[-1] - d_neg.shape[-1])) + d_neg = F.pad(d_neg, (0, neg_A.shape[-2] - d_neg.shape[-2], + 0, neg_A.shape[-1] - d_neg.shape[-1])) neg_A = neg_A.view(shape[0], shape[1], *neg_A.shape[1:]) next_A = self.jit_mutiply(pos_A, neg_A, d_pos, d_neg) if self.padding[0] > 0 or self.padding[1] > 0: - next_A = next_A[...,self.padding[0]:-self.padding[0], self.padding[1]:-self.padding[1]] + next_A = next_A[...,self.padding[0]:-self.padding[0], + self.padding[1]:-self.padding[1]] elif isinstance(last_A, Patches): # The last_A.patches was not padded, so we need to pad them here. # If this Conv layer is followed by a ReLU layer, then the padding was already handled there and there is no need to pad again. - one_d = torch.ones(tuple(1 for i in self.output_shape[1:]), device=last_A.patches.device, dtype=last_A.patches.dtype).expand(self.output_shape[1:]) + one_d = torch.ones(tuple(1 for i in self.output_shape[1:]), + device=last_A.patches.device, dtype=last_A.patches.dtype).expand(self.output_shape[1:]) # Add batch dimension. one_d = one_d.unsqueeze(0) # After unfolding, the shape is (1, out_h, out_w, in_c, h, w) - one_d_unfolded = inplace_unfold(one_d, kernel_size=last_A.patches.shape[-2:], stride=last_A.stride, padding=last_A.padding, inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding) + one_d_unfolded = inplace_unfold( + one_d, kernel_size=last_A.patches.shape[-2:], + stride=last_A.stride, padding=last_A.padding, + inserted_zeros=last_A.inserted_zeros, + output_padding=last_A.output_padding) if last_A.unstable_idx is not None: # Move out_h, out_w dimension to the front for easier selection. one_d_unfolded_r = one_d_unfolded.permute(1, 2, 0, 3, 4, 5) # for sparse patches the shape is (unstable_size, batch, in_c, h, w). Batch size is 1 so no need to select here. - one_d_unfolded_r = one_d_unfolded_r[last_A.unstable_idx[1], last_A.unstable_idx[2]] + one_d_unfolded_r = one_d_unfolded_r[ + last_A.unstable_idx[1], last_A.unstable_idx[2]] else: # Append the spec dimension. one_d_unfolded_r = one_d_unfolded.unsqueeze(0) @@ -243,31 +287,23 @@ def upsample(last_patches, last_A): pos_A = upsample(pos_A, last_A) neg_A = upsample(neg_A, last_A) - stride = self.stride[0] * last_A.stride - if isinstance(last_A.padding, int): - padding = last_A.padding * self.stride[0] + self.padding[0] - else: - # Here we need to unfold the d_pos to match pos_A and neg_A patches - # And we compute the padding and stride of pos_A and neg_A - padding = tuple(a * self.stride[0] + self.padding[0] for a in last_A.padding) - padding, stride, output_padding = compute_patches_stride_padding( - self.input_shape, last_A.padding, last_A.stride, self.padding, self.stride, last_A.inserted_zeros, last_A.output_padding) + self.input_shape, last_A.padding, last_A.stride, self.padding, + self.stride, last_A.inserted_zeros, last_A.output_padding) pos_A.padding, pos_A.stride, pos_A.output_padding = padding, stride, output_padding neg_A.padding, neg_A.stride, neg_A.output_padding = padding, stride, output_padding # unsqueeze for the spec dimension - d_pos = maybe_unfold_patches(d_pos.unsqueeze(0), pos_A) - d_neg = maybe_unfold_patches(d_neg.unsqueeze(0), neg_A) - + d_pos = maybe_unfold_patches(d_pos.unsqueeze(0), pos_A) + d_neg = maybe_unfold_patches(d_neg.unsqueeze(0), neg_A) - next_A_patches = self.jit_mutiply(pos_A.patches, neg_A.patches, d_pos, d_neg) + next_A_patches = self.jit_mutiply( + pos_A.patches, neg_A.patches, d_pos, d_neg) if start_node is not None: self.patch_size[start_node.name] = next_A_patches.size() - next_A = Patches( next_A_patches, stride, padding, next_A_patches.shape, unstable_idx=last_A.unstable_idx, output_shape=last_A.output_shape, @@ -276,7 +312,8 @@ def upsample(last_patches, last_A): return next_A, bias if self.padding[0] > 0: - upper_d = upper_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + upper_d = upper_d[...,self.padding[0]:-self.padding[0], + self.padding[0]:-self.padding[0]] uA, ubias = _bound_oneside(last_uA, upper_d, lower_d, upper_b, lower_b) lA, lbias = _bound_oneside(last_lA, lower_d, upper_d, lower_b, upper_b) @@ -284,7 +321,7 @@ def upsample(last_patches, last_A): return [(lA, uA)], lbias, ubias def bound_forward(self, dim_in, x): - lower_d, lower_b, upper_d, upper_b = self.bound_relax(x) + lower_d, lower_b, upper_d, upper_b = self.bound_relax(x, init=False) def _bound_oneside(w_pos, b_pos, w_neg, b_neg, d, b): d_pos, d_neg = d.clamp(min=0), d.clamp(max=0) @@ -309,7 +346,10 @@ def _bound_oneside(w_pos, b_pos, w_neg, b_neg, d, b): return LinearBound(lw, lb, uw, ub) - def bound_relax(self, x): + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + # Only used by forward mode paddings = tuple(self.padding + self.padding) self.upper, self.lower = x.upper, x.lower @@ -328,41 +368,27 @@ def bound_relax(self, x): # 1. find the index i where li > uj for all j, then set upper_d = lower_d = 1 max_lower, max_lower_index = F.max_pool2d(x.lower, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) - delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -np.inf).view(upper_d.shape) + delete_upper = torch.scatter(torch.flatten(F.pad(x.upper, paddings), -2), -1, torch.flatten(max_lower_index, -2), -torch.inf).view(upper_d.shape) max_upper, _ = F.max_pool2d(delete_upper, self.kernel_size, self.stride, 0, return_indices=True, ceil_mode=self.ceil_mode) values = torch.zeros_like(max_lower) values[max_lower >= max_upper] = 1.0 upper_d = torch.scatter(torch.flatten(upper_d, -2), -1, torch.flatten(max_lower_index, -2), torch.flatten(values, -2)).view(upper_d.shape) - # FIXME shape error - if False and self.opt_stage == 'opt': - alpha = self.alpha[self._start] - - if self.init[self._start] == False: - lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) - lower_d_unfold = F.unfold(lower_d, self.kernel_size, 1, stride=self.stride) - - alpha_data = lower_d_unfold.view(lower_d.shape[0], lower_d.shape[1], self.kernel_size[0], self.kernel_size[1], self.output_shape[-2], self.output_shape[-1]) - alpha.data.copy_(alpha_data.permute((0,1,4,5,2,3)).clone().detach()) - self.init[self._start] = True - if self.padding[0] > 0: - lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] - - alpha.data = self.project_simplex(alpha.data).clone().detach() - alpha = alpha.permute((0,1,2,3,6,7,4,5)) - alpha_shape = alpha.shape - alpha = alpha.reshape((alpha_shape[0]*alpha_shape[1]*alpha_shape[2], -1, alpha_shape[-2]*alpha_shape[-1])) - lower_d = F.fold(alpha, self.input_shape[-2:], self.kernel_size, 1, self.padding, self.stride) - lower_d = lower_d.view(alpha_shape[0], alpha_shape[1], alpha_shape[2], *lower_d.shape[1:]) - lower_d = lower_d.squeeze(0) + if self.opt_stage == 'opt': + raise NotImplementedError else: - lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, torch.flatten(max_lower_index, -2), 1.0).view(upper_d.shape) + lower_d = torch.scatter(torch.flatten(lower_d, -2), -1, + torch.flatten(max_lower_index, -2), + 1.0).view(upper_d.shape) if self.padding[0] > 0: - lower_d = lower_d[...,self.padding[0]:-self.padding[0], self.padding[0]:-self.padding[0]] + lower_d = lower_d[...,self.padding[0]:-self.padding[0], + self.padding[0]:-self.padding[0]] values[:] = 0.0 - max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, self.padding, return_indices=True, ceil_mode=self.ceil_mode) + max_upper_, _ = F.max_pool2d(x.upper, self.kernel_size, self.stride, + self.padding, return_indices=True, + ceil_mode=self.ceil_mode) values[max_upper > max_lower] = max_upper_[max_upper > max_lower] upper_b = values @@ -371,6 +397,15 @@ def bound_relax(self, x): return lower_d, lower_b, upper_d, upper_b + def dump_optimized_params(self): + ret = {'alpha': self.alpha} + ret['init'] = self.init + return ret + + def restore_optimized_params(self, alpha): + self.alpha = alpha['alpha'] + self.init = alpha['init'] + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): # e.g., last layer input gurobi vars (3,32,32) gvars_array = np.array(v[0]) @@ -407,7 +442,8 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") a = model.addVar(vtype=grb.GRB.BINARY) a_sum += a model.addConstr(v >= var) - model.addConstr(v <= var + (1 - a) * pre_ubs[0, out_chan_idx, out_row_idx, out_col_idx]) + model.addConstr(v <= var + (1 - a) * pre_ubs[ + 0, out_chan_idx, out_row_idx, out_col_idx]) model.addConstr(a_sum == 1, name=f'lay{self.name}_{neuron_idx}_eq') out_row_vars.append(v) out_chan_vars.append(out_row_vars) @@ -419,14 +455,14 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") class BoundGlobalAveragePool(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) def forward(self, x): output = nn.AdaptiveAvgPool2d((1, 1)).forward(x) # adaptiveAveragePool with output size (1, 1) return output - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): H, W = self.input_shape[-2], self.input_shape[-1] lA = (last_lA.expand(list(last_lA.shape[:-2]) + [H, W]) / (H * W)) if last_lA is not None else None @@ -442,7 +478,7 @@ def interval_propagate(self, *v): class BoundAveragePool(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): # assumptions: ceil_mode=False, count_include_pad=True super().__init__(attr, inputs, output_index, options) @@ -467,7 +503,7 @@ def forward(self, x): return F.avg_pool2d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): def _bound_oneside(last_A): if last_A is None: return None, 0 @@ -483,21 +519,31 @@ def _bound_oneside(last_A): patches = last_A.patches shape = patches.size() # When the number of inserted zeros can cancel out the stride, we use a shortcut that can reduce computation. - simplify_patch = (last_A.inserted_zeros + 1 == self.kernel_size[0]) and (self.kernel_size[0] == self.kernel_size[1]) + simplify_patch = ((last_A.inserted_zeros + 1 == self.kernel_size[0]) + and (self.kernel_size[0] == self.kernel_size[1])) padding, stride, output_padding = compute_patches_stride_padding( - self.input_shape, last_A.padding, last_A.stride, self.padding, self.stride, - inserted_zeros=last_A.inserted_zeros, output_padding=last_A.output_padding, simplify=not simplify_patch) + self.input_shape, last_A.padding, last_A.stride, + self.padding, self.stride, + inserted_zeros=last_A.inserted_zeros, + output_padding=last_A.output_padding, + simplify=not simplify_patch) inserted_zeros = last_A.inserted_zeros if last_A.inserted_zeros == 0: # No inserted zeros, can be handled using interpolate. if last_A.unstable_idx is None: # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W] - up_sampled_patches = F.interpolate(patches.view(shape[0] * shape[1], shape[2] * shape[3], *shape[4:]), scale_factor=[1,] + self.kernel_size) + up_sampled_patches = F.interpolate( + patches.view(shape[0] * shape[1], + shape[2] * shape[3], *shape[4:]), + scale_factor=[1,] + self.kernel_size) # The dimension of patch-H and patch_W has changed. - up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1)) + up_sampled_patches = up_sampled_patches.view( + *shape[:-2], up_sampled_patches.size(-2), + up_sampled_patches.size(-1)) else: # shape is: [spec, batch, in_c, patch_H, patch_W] - up_sampled_patches = F.interpolate(patches, scale_factor=[1,] + self.kernel_size) + up_sampled_patches = F.interpolate( + patches, scale_factor=[1,] + self.kernel_size) # Divided by the averaging factor. up_sampled_patches = up_sampled_patches / prod(self.kernel_size) elif simplify_patch: @@ -507,27 +553,50 @@ def _bound_oneside(last_A): inserted_zeros = 0 value = 1. / prod(self.kernel_size) # In the case where the stride and adding_zeros cancel out, we do not need to insert zeros. - weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), fill_value=value, dtype=patches.dtype, device=patches.device) + weight = torch.full( + size=(self.input_shape[1], 1, *self.kernel_size), + fill_value=value, dtype=patches.dtype, + device=patches.device) if last_A.unstable_idx is None: # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W] - up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), weight, stride=1, groups=self.input_shape[1]) + up_sampled_patches = F.conv_transpose2d( + patches.reshape( + shape[0] * shape[1] * shape[2] * shape[3], + *shape[4:] + ), weight, stride=1, groups=self.input_shape[1]) else: # shape is: [spec, batch, in_c, patch_H, patch_W] - up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1], *shape[2:]), weight, stride=1, groups=self.input_shape[1]) - up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1)) + up_sampled_patches = F.conv_transpose2d( + patches.reshape(shape[0] * shape[1], *shape[2:]), + weight, stride=1, groups=self.input_shape[1]) + up_sampled_patches = up_sampled_patches.view( + *shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1)) else: # With inserted zeros, must be handled by treating pooling as general convolution. value = 1. / prod(self.kernel_size) - weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), fill_value=value, dtype=patches.dtype, device=patches.device) + weight = torch.full(size=(self.input_shape[1], 1, *self.kernel_size), + fill_value=value, dtype=patches.dtype, + device=patches.device) weight = insert_zeros(weight, last_A.inserted_zeros) if last_A.unstable_idx is None: # shape is: [out_C, batch, out_H, out_W, in_c, patch_H, patch_W] - up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), weight, stride=self.kernel_size, groups=self.input_shape[1]) + up_sampled_patches = F.conv_transpose2d( + patches.reshape(shape[0] * shape[1] * shape[2] * shape[3], *shape[4:]), + weight, stride=self.kernel_size, + groups=self.input_shape[1]) else: # shape is: [spec, batch, in_c, patch_H, patch_W] - up_sampled_patches = F.conv_transpose2d(patches.reshape(shape[0] * shape[1], *shape[2:]), weight, stride=self.kernel_size, groups=self.input_shape[1]) - up_sampled_patches = up_sampled_patches.view(*shape[:-2], up_sampled_patches.size(-2), up_sampled_patches.size(-1)) - next_A = last_A.create_similar(up_sampled_patches, stride=stride, padding=padding, output_padding=output_padding, inserted_zeros=inserted_zeros) + up_sampled_patches = F.conv_transpose2d( + patches.reshape(shape[0] * shape[1], *shape[2:]), + weight, stride=self.kernel_size, + groups=self.input_shape[1]) + up_sampled_patches = up_sampled_patches.view( + *shape[:-2], up_sampled_patches.size(-2), + up_sampled_patches.size(-1)) + next_A = last_A.create_similar( + up_sampled_patches, stride=stride, padding=padding, + output_padding=output_padding, + inserted_zeros=inserted_zeros) else: raise ValueError(f'last_A has unexpected shape {type(last_A)}') return next_A, 0. @@ -543,7 +612,9 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") pre_layer_shape = np.expand_dims(gvars_array, axis=0).shape # this layer shape (1,32,6,6) this_layer_shape = self.output_shape - assert this_layer_shape[2] == ((2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1))//self.stride[0]) + assert this_layer_shape[2] == ( + (2 * self.padding[0] + pre_layer_shape[2] - (self.stride[0] - 1) + ) // self.stride[0]) value = 1.0/(self.kernel_size[0] * self.kernel_size[1]) new_layer_gurobi_vars = [] diff --git a/auto_LiRPA/operators/reduce.py b/auto_LiRPA/operators/reduce.py index 6a57d49..6b59ae0 100644 --- a/auto_LiRPA/operators/reduce.py +++ b/auto_LiRPA/operators/reduce.py @@ -2,30 +2,47 @@ from .base import * -class BoundReduceMax(Bound): - def __init__(self, attr, inputs, output_index, options): +class BoundReduce(Bound): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) - self.axis = attr['axes'] - # for torch.max, `dim` must be an int - if isinstance(self.axis, list): - assert len(self.axis) == 1 - self.axis = self.axis[0] + self.axis = attr.get('axes', None) self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True self.use_default_ibp = True + def _parse_input_and_axis(self, *x): + if len(x) > 1: + assert not self.is_input_perturbed(1) + self.axis = tuple(x[1]) + self.axis = self.make_axis_non_negative(self.axis) + return x[0] + + def _return_bound_backward(self, lA, uA): + return [(lA, uA)] + [(None, None)] * (len(self.inputs) - 1), 0, 0 + + +class BoundReduceMax(BoundReduce): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) """Assume that the indexes with the maximum values are not perturbed. This generally doesn't hold true, but can still be used for the input shift in Softmax of Transformers.""" self.fixed_max_index = options.get('fixed_reducemax_index', False) - def forward(self, x): - self.axis = self.make_axis_non_negative(self.axis) - assert self.axis > 0 + def _parse_input_and_axis(self, *x): + x = super()._parse_input_and_axis(*x) + # for torch.max, `dim` must be an int + if isinstance(self.axis, tuple): + assert len(self.axis) == 1 + self.axis = self.axis[0] + return x + + def forward(self, *x): + x = self._parse_input_and_axis(*x) res = torch.max(x, dim=self.axis, keepdim=self.keepdim) self.indices = res.indices return res.values - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, *args, **kwargs): if self.fixed_max_index: def _bound_oneside(last_A): if last_A is None: @@ -38,30 +55,32 @@ def _bound_oneside(last_A): shape = list(last_A.shape) shape[self.axis + 1] *= self.input_shape[self.axis] A = torch.zeros(shape, device=last_A.device) + indices = indices.expand(*last_A.shape) A.scatter_(dim=self.axis + 1, index=indices, src=last_A) return A - return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0 + return self._return_bound_backward(_bound_oneside(last_lA), + _bound_oneside(last_uA)) else: - raise NotImplementedError('`bound_backward` for BoundReduceMax with perturbed maximum indexes is not implemented.') + raise NotImplementedError( + '`bound_backward` for BoundReduceMax with perturbed maximum' + 'indexes is not implemented.') -class BoundReduceMean(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.axis = attr['axes'] - self.keepdim = bool(attr['keepdims']) if 'keepdims' in attr else True - self.use_default_ibp = True +class BoundReduceMin(BoundReduceMax): + def forward(self, *x): + x = self._parse_input_and_axis(*x) + res = torch.min(x, dim=self.axis, keepdim=self.keepdim) + self.indices = res.indices + return res.values - def forward(self, x): - return torch.mean(x, dim=self.axis, keepdim=self.keepdim) - def bound_backward(self, last_lA, last_uA, x): - for i in range(len(self.axis)): - if self.axis[i] < 0: - self.axis[i] = self.make_axis_non_negative(self.axis[i]) - assert self.axis[i] > 0 +class BoundReduceMean(BoundReduce): + def forward(self, *x): + x = self._parse_input_and_axis(*x) + return torch.mean(x, dim=self.axis, keepdim=self.keepdim) + def bound_backward(self, last_lA, last_uA, *args, **kwargs): def _bound_oneside(last_A): if last_A is None: return None @@ -77,10 +96,11 @@ def _bound_oneside(last_A): last_A = last_A.expand(*shape) / size_axis return last_A - return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0 + return self._return_bound_backward(_bound_oneside(last_lA), + _bound_oneside(last_uA)) - def bound_forward(self, dim_in, x): - assert (self.keepdim) + def bound_forward(self, dim_in, x, *args): + assert self.keepdim assert (len(self.axis) == 1) axis = self.make_axis_non_negative(self.axis[0]) assert (axis > 0) @@ -91,25 +111,16 @@ def bound_forward(self, dim_in, x): ub = x.ub.sum(dim=axis, keepdim=True) / size return LinearBound(lw, lb, uw, ub) -class BoundReduceSum(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - self.axis = attr['axes'] if 'axes' in attr else None - self.keepdim = bool(attr['keepdims']) - self.use_default_ibp = True - def forward(self, x): +class BoundReduceSum(BoundReduce): + def forward(self, *x): + x = self._parse_input_and_axis(*x) if self.axis is not None: return torch.sum(x, dim=self.axis, keepdim=self.keepdim) else: return torch.sum(x) - def bound_backward(self, last_lA, last_uA, x): - for i in range(len(self.axis)): - if self.axis[i] < 0: - self.axis[i] = len(self.input_shape) + self.axis[i] - assert self.axis[i] > 0 - + def bound_backward(self, last_lA, last_uA, x, *args, **kwargs): def _bound_oneside(last_A): if last_A is None: return None @@ -124,9 +135,10 @@ def _bound_oneside(last_A): last_A = last_A.expand(*shape) return last_A - return [(_bound_oneside(last_lA), _bound_oneside(last_uA))], 0, 0 + return self._return_bound_backward(_bound_oneside(last_lA), + _bound_oneside(last_uA)) - def bound_forward(self, dim_in, x): + def bound_forward(self, dim_in, x, *args): assert len(self.axis) == 1 axis = self.make_axis_non_negative(self.axis[0]) assert axis > 0 diff --git a/auto_LiRPA/operators/relu.py b/auto_LiRPA/operators/relu.py new file mode 100644 index 0000000..6cb7f80 --- /dev/null +++ b/auto_LiRPA/operators/relu.py @@ -0,0 +1,931 @@ +"""BoundRelu.""" +from typing import Optional, Tuple +import torch +from torch import Tensor +from collections import OrderedDict +from .base import * +from .clampmult import multiply_by_A_signs +from .activation_base import BoundActivation, BoundOptimizableActivation +from .gradient_modules import ReLUGrad +from .solver_utils import grb +from ..utils import unravel_index, prod + + +class BoundTwoPieceLinear(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.options = options + self.ibp_intermediate = True + self.splittable = True + self.use_sparse_spec_alpha = options.get('sparse_spec_alpha', False) + self.use_sparse_features_alpha = options.get('sparse_features_alpha', False) + self.alpha_lookup_idx = self.alpha_indices = None + self.beta = self.masked_beta = self.sparse_betas = None + self.split_beta_used = False + self.history_beta_used = False + self.flattened_nodes = None + self.patch_size = {} + self.cut_used = False + self.cut_module = None + + def init_opt_parameters(self, start_nodes): + ref = self.inputs[0].lower # a reference variable for getting the shape + batch_size = ref.size(0) + self.alpha = OrderedDict() + self.alpha_lookup_idx = OrderedDict() # For alpha with sparse spec dimention. + self.alpha_indices = None # indices of non-zero alphas. + verbosity = self.options.get('verbosity', 0) + + # Alpha can be sparse in both spec dimension, and the C*H*W dimension. + # We first deal with the sparse-feature alpha, which is sparse in the + # C*H*W dimesnion of this layer. + minimum_sparsity = self.options.get('minimum_sparsity', 0.9) + if (self.use_sparse_features_alpha + and hasattr(self.inputs[0], 'lower') + and hasattr(self.inputs[0], 'upper')): + # Pre-activation bounds available, we will store the alpha for unstable neurons only. + # Since each element in a batch can have different unstable neurons, + # for simplicity we find a super-set using any(dim=0). + # This can be non-ideal if the x in a batch are very different. + self.get_unstable_idx() + total_neuron_size = self.inputs[0].lower.numel() // batch_size + if self.alpha_indices[0].size(0) <= minimum_sparsity * total_neuron_size: + # Shape is the number of unstable neurons in this layer. + alpha_shape = [self.alpha_indices[0].size(0)] + # Skip the batch, spec dimension, and find the lower slopes for all unstable neurons. + if len(self.alpha_indices) == 1: + # This layer is after a linear layer. + alpha_init = self.init_d[:, :, self.alpha_indices[0]] + elif len(self.alpha_indices) == 3: + # This layer is after a conv2d layer. + alpha_init = self.init_d[ + :, :, self.alpha_indices[0], self.alpha_indices[1], + self.alpha_indices[2]] + elif len(self.alpha_indices) == 2: + # This layer is after a conv1d layer. + alpha_init = self.init_d[ + :, :, self.alpha_indices[0], self.alpha_indices[1]] + else: + raise ValueError + if verbosity > 0: + print(f'layer {self.name} using sparse-features alpha with shape {alpha_shape}; unstable size ' + f'{self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({list(ref.shape)})') + else: + alpha_shape = self.shape # Full alpha. + alpha_init = self.init_d + if verbosity > 0: + print(f'layer {self.name} using full alpha with shape {alpha_shape}; unstable size ' + f'{self.alpha_indices[0].size(0)}; total size {total_neuron_size} ({list(ref.shape)})') + self.alpha_indices = None # Use full alpha. + else: + alpha_shape = self.shape # Full alpha. + alpha_init = self.init_d + # Now we start to create alphas for all start nodes. + # When sparse-spec feature is enabled, alpha is created for only + # unstable neurons in start node. + for start_node in start_nodes: + ns, output_shape, unstable_idx = start_node[:3] + if isinstance(output_shape, (list, tuple)): + if len(output_shape) > 1: + size_s = prod(output_shape) # Conv layers. + else: + size_s = output_shape[0] + else: + size_s = output_shape + # unstable_idx may be a tensor (dense layer or conv layer + # with shared alpha), or tuple of 3-d tensors (conv layer with + # non-sharing alpha). + sparsity = float('inf') if unstable_idx is None else unstable_idx.size(0) if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].size(0) + if sparsity <= minimum_sparsity * size_s and self.use_sparse_spec_alpha: + # For fully connected layer, or conv layer with shared alpha per channel. + # shape is (2, sparse_spec, batch, this_layer_shape) + # We create sparse specification dimension, where the spec dimension of alpha only includes slopes for unstable neurons in start_node. + self.alpha[ns] = torch.empty([self.alpha_size, sparsity + 1, batch_size, *alpha_shape], + dtype=torch.float, device=ref.device, requires_grad=True) + self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, sparse_spec) dimensions. + if verbosity > 0: + print(f'layer {self.name} start_node {ns} using sparse-spec alpha {list(self.alpha[ns].size())}' + f' with unstable size {sparsity} total_size {size_s} output_shape {output_shape}') + # unstable_idx is a list of used neurons (or channels for BoundConv) for the start_node. + assert unstable_idx.ndim == 1 if isinstance(unstable_idx, torch.Tensor) else unstable_idx[0].ndim == 1 + # We only need to the alpha for the unstable neurons in start_node. + indices = torch.arange(1, sparsity + 1, device=alpha_init.device, dtype=torch.long) + if isinstance(output_shape, int) or len(output_shape) == 1: + # Fully connected layers, or conv layer in patches mode with partially shared alpha (pixels in the same channel use the same alpha). + self.alpha_lookup_idx[ns] = torch.zeros(size_s, dtype=torch.long, device=alpha_init.device) + # This lookup table maps the unstable_idx to the actual alpha location in self.alpha[ns]. + # Note that self.alpha[ns][:,0] is reserved for any unstable neurons that are not found in the lookup table. This usually should not + # happen, unless reference bounds are not properly set. + self.alpha_lookup_idx[ns].data[unstable_idx] = indices + else: + # conv layer in matrix mode, or in patches mode but with non-shared alpha. The lookup table is 3-d. + assert len(output_shape) == 3 + self.alpha_lookup_idx[ns] = torch.zeros(output_shape, dtype=torch.long, device=alpha_init.device) + if isinstance(unstable_idx, torch.Tensor): + # Convert the unstable index from flattend 1-d to 3-d. (matrix mode). + unstable_idx_3d = unravel_index(unstable_idx, output_shape) + else: + # Patches mode with non-shared alpha, unstable_idx is already 3d. + unstable_idx_3d = unstable_idx + # Build look-up table. + self.alpha_lookup_idx[ns].data[unstable_idx_3d[0], unstable_idx_3d[1], unstable_idx_3d[2]] = indices + else: + # alpha shape is (2, spec, batch, this_layer_shape). "this_layer_shape" may still be sparse. + self.alpha[ns] = torch.empty([self.alpha_size, size_s, batch_size, *alpha_shape], + dtype=torch.float, device=ref.device, requires_grad=True) + self.alpha[ns].data.copy_(alpha_init.data) # This will broadcast to (2, spec) dimensions + if verbosity > 0: + print(f'layer {self.name} start_node {ns} using full alpha {list(self.alpha[ns].size())} with unstable ' + f'size {sparsity if unstable_idx is not None else None} total_size {size_s} output_shape {output_shape}') + # alpha_lookup_idx can be used for checking if sparse alpha is used or not. + self.alpha_lookup_idx[ns] = None + + def select_alpha_by_idx(self, last_lA, last_uA, unstable_idx, start_node, alpha_lookup_idx): + # Each alpha has shape (2, output_shape, batch_size, *relu_node_shape]. + # If slope is shared, output_shape will be 1. + # The *relu_node_shape might be sparse (sparse-feature alpha), where the non-zero values are indicated by self.alpha_indices. + # The out_shape might be sparse (sparse-spec alpha), where the non-zero values are indexed by self.alpha_lookup_idx. + if unstable_idx is not None: + # print(f'relu layer {self.name}, start_node {start_node}, unstable_idx {type(unstable_idx)} alpha idx {self.alpha_lookup_idx[start_node.name].size()}') + if self.alpha_lookup_idx is not None: + alpha_lookup_idx = self.alpha_lookup_idx[start_node.name] + else: + alpha_lookup_idx = None + if isinstance(unstable_idx, tuple): + # Start node is a conv node. + selected_alpha = self.alpha[start_node.name] + if isinstance(last_lA, Tensor) or isinstance(last_uA, Tensor): + # Start node is a conv node but we received tensors as A matrices. + # Patches mode converted to matrix, or matrix mode used. Need to select accross the spec dimension. + # For this node, since it is in matrix mode, the spec dimension is out_c * out_h * out_w + # Shape is [2, spec, batch, *this_layer_shape] + if alpha_lookup_idx is None: + if self.options['optimize_bound_args'].get('use_shared_alpha', False): + # alpha is shared, and its spec dimension is always 1. In this case we do not need to select. + # selected_alpha will have shape [2, 1, batch, *this_layer_shape] + pass + else: + # alpha is not shared, so it has shape [2, spec, batch, *this_layer_shape] + # Reshape the spec dimension to c*h*w so we can select used alphas based on unstable index. + # Shape becomes [2, out_c, out_h, out_w, batch, *this_layer_shape] + selected_alpha = selected_alpha.view(selected_alpha.size(0), *start_node.output_shape[1:], *selected_alpha.shape[2:]) + selected_alpha = selected_alpha[:, unstable_idx[0], unstable_idx[1], unstable_idx[2]] + else: + assert alpha_lookup_idx.ndim == 3 + # We only stored some alphas, and A is also sparse, so the unstable_idx must be first translated to real indices. + # alpha shape is (2, sparse_spec_shape, batch_size, *relu_node_shape) where relu_node_shape can also be sparse. + # We use sparse-spec alphas. Need to convert these unstable_idx[0], unstable_idx[1], unstable_idx[0] using lookup table. + _unstable_idx = alpha_lookup_idx[unstable_idx[0], unstable_idx[1], unstable_idx[2]] + selected_alpha = self.non_deter_index_select(selected_alpha, index=_unstable_idx, dim=1) + else: + # Patches mode. Alpha must be selected after unfolding, so cannot be done here. + # Selection is deferred to maybe_unfold() using alpha_lookup_idx. + # For partially shared alpha, its shape is (2, out_c, batch_size, *relu_node_shape). + # For full alpha, its shape is (2, out_c*out_h*out_w, batch_size, *relu_node_shape). + # Both the spec dimension and relu_node_shape dimensions can be sparse. + pass + elif unstable_idx.ndim == 1: + # Start node is a FC node. + # Only unstable neurons of the start_node neurons are used. + assert alpha_lookup_idx is None or alpha_lookup_idx.ndim == 1 + if self.options['optimize_bound_args'].get('use_shared_alpha', False): + # Shared alpha is used, all output specs use the same alpha. No selection is needed. + # The spec dim is 1 and will be broadcast. + selected_alpha = self.alpha[start_node.name] + else: + _unstable_idx = alpha_lookup_idx[unstable_idx] if alpha_lookup_idx is not None else unstable_idx + selected_alpha = self.non_deter_index_select(self.alpha[start_node.name], index=_unstable_idx, dim=1) + elif unstable_idx.ndim == 2: + assert alpha_lookup_idx is None, "sparse spec alpha has not been implemented yet." + # Each element in the batch selects different neurons. + selected_alpha = batched_index_select(self.alpha[start_node.name], index=unstable_idx, dim=1) + else: + raise ValueError + else: + # Spec dimension is dense. Alpha must not be created sparsely. + assert self.alpha_lookup_idx is None or self.alpha_lookup_idx[start_node.name] is None + selected_alpha = self.alpha[start_node.name] + return selected_alpha, alpha_lookup_idx + + def reconstruct_full_alpha(self, sparse_alpha, full_alpha_shape, alpha_indices): + full_alpha = torch.zeros(full_alpha_shape, dtype=sparse_alpha.dtype, device=sparse_alpha.device) + if len(alpha_indices) == 1: + # Relu after a dense layer. + full_alpha[:, :, alpha_indices[0]] = sparse_alpha + elif len(alpha_indices) == 3: + # Relu after a conv2d layer. + full_alpha[:, :, alpha_indices[0], alpha_indices[1], alpha_indices[2]] = sparse_alpha + elif len(alpha_indices) == 2: + # Relu after a conv1d layer. + full_alpha[:, :, alpha_indices[0], alpha_indices[1]] = sparse_alpha + else: + raise ValueError + return full_alpha + + def bound_backward(self, last_lA, last_uA, x=None, start_node=None, + unstable_idx=None, reduce_bias=True, **kwargs): + """ + start_node: the name of the layer where the backward bound propagation starts. + Can be the output layer or an intermediate layer. + unstable_idx: indices for the unstable neurons, whose bounds need to be computed. + Either be a tuple (for patches) or a 1-D tensor. + """ + # Usage of output constraints requires access to bounds of the previous iteration + # (see _clear_and_set_new) + apply_output_constraints_to = self.options["optimize_bound_args"]["apply_output_constraints_to"] + if hasattr(x, "lower"): + lower = x.lower + else: + assert start_node.are_output_constraints_activated_for_layer(apply_output_constraints_to) + lower = x.previous_iteration_lower + if hasattr(x, "upper"): + upper = x.upper + else: + assert start_node.are_output_constraints_activated_for_layer(apply_output_constraints_to) + upper = x.previous_iteration_upper + # Get element-wise CROWN linear relaxations. + (upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, + lb_upper_d, ub_upper_d, alpha_lookup_idx) = \ + self._backward_relaxation(last_lA, last_uA, x, start_node, unstable_idx) + # save for calculate babsr score + self.d = upper_d + self.lA = last_lA + # Save for initialization bounds. + self.init_d = lower_d + + # Choose upper or lower bounds based on the sign of last_A + def _bound_oneside(last_A, d_pos, d_neg, b_pos, b_neg): + if last_A is None: + return None, 0 + # Obtain the new linear relaxation coefficients based on the signs in last_A. + _A, _bias = multiply_by_A_signs( + last_A, d_pos, d_neg, b_pos, b_neg, reduce_bias=reduce_bias) + if isinstance(last_A, Patches): + # Save the patch size, which will be used in init_alpha() to determine the number of optimizable parameters. + A_prod = _A.patches + if start_node is not None: + if last_A.unstable_idx is not None: + # Sparse patches, we need to construct the full patch size: (out_c, batch, out_h, out_w, c, h, w). + self.patch_size[start_node.name] = [ + last_A.output_shape[1], A_prod.size(1), + last_A.output_shape[2], last_A.output_shape[3], + A_prod.size(-3), A_prod.size(-2), A_prod.size(-1)] + else: + # Regular patches. + self.patch_size[start_node.name] = A_prod.size() + return _A, _bias + + ######## A problem with patches mode for cut constraint start ########## + # There are cases that the node that is in the constraint but not selected by the patches for the output node + # trick: only count the small patches that have all the split node coeffs[ci].sum() equal to coeffs_unfolded[ci][out_h, out_w, -1].sum() + # we should force these beta to be 0 to disable the effect of these constraints + A = last_lA if last_lA is not None else last_uA + current_layer_shape = lower.size()[1:] + if self.cut_used and type(A) is Patches: + self.cut_module.patch_trick(start_node, self.name, A, current_layer_shape) + ######## A problem with patches mode for cut constraint end ########## + + if self.cut_used: + if self.leaky_alpha > 0: + raise NotImplementedError + # propagate postrelu node in cut constraints + last_lA, last_uA = self.cut_module.relu_cut( + start_node, self.name, last_lA, last_uA, current_layer_shape, + unstable_idx, batch_mask=self.inputs[0].alpha_beta_update_mask) + + # In patches mode we might need an unfold. + # lower_d, upper_d, lower_b, upper_b: 1, batch, current_c, current_w, current_h or None + upper_d = maybe_unfold_patches(upper_d, last_lA if last_lA is not None else last_uA) + lower_d = maybe_unfold_patches(lower_d, last_lA if last_lA is not None else last_uA) + upper_b = maybe_unfold_patches(upper_b, last_lA if last_lA is not None else last_uA) + lower_b = maybe_unfold_patches(lower_b, last_lA if last_lA is not None else last_uA) # for ReLU it is always None; keeping it here for completeness. + # ub_lower_d and lb_lower_d might have sparse spec dimension, so they may need alpha_lookup_idx to convert to actual spec dim. + ub_lower_d = maybe_unfold_patches(ub_lower_d, last_uA, alpha_lookup_idx=alpha_lookup_idx) + ub_upper_d = maybe_unfold_patches(ub_upper_d, last_uA, alpha_lookup_idx=alpha_lookup_idx) + # optimizable slope lb_lower_d: spec (only channels in spec layer), batch, current_c, current_w, current_h + # patches mode lb_lower_d after unfold: unstable, batch, in_C, H, W + lb_lower_d = maybe_unfold_patches(lb_lower_d, last_lA, alpha_lookup_idx=alpha_lookup_idx) + lb_upper_d = maybe_unfold_patches(lb_upper_d, last_lA, alpha_lookup_idx=alpha_lookup_idx) + + if self.cut_used: + assert reduce_bias + I = (lower < 0) * (upper > 0) + # propagate integer var of relu neuron (arelu) in cut constraints through relu layer + lA, uA, lbias, ubias = self.cut_module.arelu_cut( + start_node, self.name, last_lA, last_uA, lower_d, upper_d, + lower_b, upper_b, lb_lower_d, ub_lower_d, I, x, self.patch_size, + current_layer_shape, unstable_idx, + batch_mask=self.inputs[0].alpha_beta_update_mask) + else: + uA, ubias = _bound_oneside( + last_uA, ub_upper_d if upper_d is None else upper_d, + ub_lower_d if lower_d is None else lower_d, upper_b, lower_b) + lA, lbias = _bound_oneside( + last_lA, lb_lower_d if lower_d is None else lower_d, + lb_upper_d if upper_d is None else upper_d, lower_b, upper_b) + + if self.cut_used: + # propagate prerelu node in cut constraints + lA, uA = self.cut_module.pre_cut( + start_node, self.name, lA, uA, current_layer_shape, unstable_idx, + batch_mask=self.inputs[0].alpha_beta_update_mask) + self.masked_beta_lower = self.masked_beta_upper = None + + return [(lA, uA)], lbias, ubias + + def dump_optimized_params(self): + ret = {'alpha': self.alpha} + if self.use_sparse_spec_alpha: + ret['alpha_lookup_idx'] = self.alpha_lookup_idx + if self.use_sparse_features_alpha: + ret['alpha_indices'] = self.alpha_indices + return ret + + def restore_optimized_params(self, alpha): + self.alpha = alpha['alpha'] + if self.use_sparse_spec_alpha: + self.alpha_lookup_idx = alpha['alpha_lookup_idx'] + if self.use_sparse_features_alpha: + self.alpha_indices = alpha['alpha_indices'] + + +class BoundRelu(BoundTwoPieceLinear): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.relu_options = options.get('relu', 'adaptive') # FIXME: use better names. + self.leaky_alpha = attr.get('alpha', 0) + self.alpha_size = 2 + # Alpha dimension is (2, output_shape, batch, *shape) for ReLU. + + def get_unstable_idx(self): + self.alpha_indices = torch.logical_and( + self.inputs[0].lower < 0, self.inputs[0].upper > 0).any(dim=0).nonzero(as_tuple=True) + + def clip_alpha(self): + for v in self.alpha.values(): + v.data = torch.clamp(v.data, self.leaky_alpha, 1.) + + def forward(self, x): + self.shape = x.shape[1:] + if self.flattened_nodes is None: + self.flattened_nodes = x[0].reshape(-1).shape[0] + if self.leaky_alpha > 0: + return F.leaky_relu(x, negative_slope=self.leaky_alpha) + else: + return F.relu(x) + + def _relu_lower_bound_init(self, upper_k): + """Return the initial lower bound without relaxation.""" + if self.relu_options == "same-slope": + # the same slope for upper and lower + lower_k = upper_k + elif self.relu_options == "zero-lb": + # Always use slope 0 as lower bound. Any value between 0 and 1 is a valid lower bound for CROWN + lower_k = torch.zeros_like(upper_k) + lower_k = (upper_k >= 1.0).to(upper_k) + if self.leaky_alpha > 0: + lower_k += (upper_k < 1.0).to(upper_k) * self.leaky_alpha + elif self.relu_options == "one-lb": + # Always use slope 1 as lower bound + lower_k = ((upper_k > self.leaky_alpha).to(upper_k) + + (upper_k <= self.leaky_alpha).to(upper_k) + * self.leaky_alpha) + else: + # adaptive + if self.leaky_alpha == 0: + lower_k = (upper_k > 0.5).to(upper_k) + else: + # FIXME this may not be optimal for leaky relu + lower_k = ((upper_k > 0.5).to(upper_k) + + (upper_k <= 0.5).to(upper_k) * self.leaky_alpha) + return lower_k + + def _forward_relaxation(self, x): + self._init_masks(x) + self.mask_pos = self.mask_pos.to(x.lower) + self.mask_both = self.mask_both.to(x.lower) + + upper_k, upper_b = self._relu_upper_bound( + x.lower, x.upper, self.leaky_alpha) + self.uw = self.mask_pos + self.mask_both * upper_k + self.ub = self.mask_both * upper_b + + if self.opt_stage in ['opt', 'reuse']: + # Each actual alpha in the forward mode has shape (batch_size, *relu_node_shape]. + # But self.alpha has shape (2, output_shape, batch_size, *relu_node_shape] + # and we do not need its first two dimensions. + lower_k = self.alpha['_forward'][0, 0] + else: + lower_k = self._relu_lower_bound_init(upper_k) + + # NOTE #FIXME Saved for initialization bounds for optimization. + # In the backward mode, same-slope bounds are used. + # But here it is using adaptive bounds which seem to be better + # for nn4sys benchmark with loose input bounds. Need confirmation + # for other cases. + self.lower_d = lower_k.detach() # saved for initializing optimized bounds + + self.lw = self.mask_both * lower_k + self.mask_pos + + def bound_dynamic_forward(self, x, max_dim=None, offset=0): + if self.leaky_alpha > 0: + raise NotImplementedError + + self._init_masks(x) + self.mask_pos = self.mask_pos.to(x.lower) + self.mask_both = self.mask_both.to(x.lower) + + upper_k, upper_b = self._relu_upper_bound( + x.lower, x.upper, self.leaky_alpha) + w_new = (self.mask_pos.unsqueeze(1) * x.lw + + self.mask_both.unsqueeze(1) * upper_k.unsqueeze(1) * x.lw) + upper_b = self.mask_both * upper_b / 2 + b_new = (self.mask_pos * x.lb + + self.mask_both * upper_k * x.lb + upper_b) + + # Create new variables for unstable ReLU + batch_size = w_new.shape[0] + device = w_new.device + unstable = self.mask_both.view(batch_size, -1) + tot_unstable = int(unstable.sum(dim=-1).max()) + tot_dim = x.tot_dim + tot_unstable + + if offset + w_new.shape[1] < x.tot_dim: + return LinearBound( + w_new, b_new, w_new, b_new, x_L=x.x_L, x_U=x.x_U, tot_dim=tot_dim) + + index = torch.cumsum(unstable, dim=-1).to(torch.int64) + index = (index - (offset + w_new.shape[1] - x.tot_dim)).clamp(min=0) + num_new_dim = int(index.max()) + num_new_dim_actual = min(num_new_dim, max_dim - w_new.shape[1]) + index = index.clamp(max=num_new_dim_actual+1) + w_unstable = torch.zeros(batch_size, num_new_dim_actual + 2, unstable.size(-1), device=device) + x_L_unstable = -torch.ones(batch_size, num_new_dim_actual, device=device) + x_U_unstable = torch.ones(batch_size, num_new_dim_actual, device=device) + w_unstable.scatter_(dim=1, index=index.unsqueeze(1), src=upper_b.view(batch_size, 1, -1), reduce='add') + w_unstable = w_unstable[:, 1:-1].view(batch_size, num_new_dim_actual, *w_new.shape[2:]) + + w_new = torch.cat([w_new, w_unstable], dim=1) + x_L_new = torch.cat([x.x_L, x_L_unstable], dim=-1) + x_U_new = torch.cat([x.x_U, x_U_unstable], dim=-1) + + return LinearBound( + w_new, b_new, w_new, b_new, x_L=x_L_new, x_U=x_U_new, tot_dim=tot_dim) + + def bound_forward(self, dim_in, x): + self._forward_relaxation(x) + lb = self.lw * x.lb + ub = self.uw * x.ub + self.ub + lw = (self.lw.unsqueeze(1) * x.lw) if x.lw is not None else None + uw = (self.uw.unsqueeze(1) * x.uw) if x.uw is not None else None + if not lw.requires_grad: + del self.mask_both, self.mask_pos + del self.lw, self.uw, self.ub + return LinearBound(lw, lb, uw, ub) + + @staticmethod + @torch.jit.script + def _relu_upper_bound(lb, ub, leaky_alpha: float): + """Upper bound slope and intercept according to CROWN relaxation.""" + lb_r = lb.clamp(max=0) + ub_r = ub.clamp(min=0) + ub_r = torch.max(ub_r, lb_r + 1e-8) + if leaky_alpha > 0: + upper_d = (ub_r - leaky_alpha * lb_r) / (ub_r - lb_r) + upper_b = - lb_r * upper_d + leaky_alpha * lb_r + else: + upper_d = ub_r / (ub_r - lb_r) + upper_b = - lb_r * upper_d + return upper_d, upper_b + + @staticmethod + def _relu_mask_alpha(lower, upper, lb_lower_d : Optional[Tensor], + ub_lower_d : Optional[Tensor], leaky_alpha : float = 0, + ) -> Tuple[Optional[Tensor], Optional[Tensor], Tensor]: + lower_mask = (lower >= 0).requires_grad_(False).to(lower.dtype) + upper_mask = (upper <= 0).requires_grad_(False) + if leaky_alpha > 0: + zero_coeffs = False + else: + zero_coeffs = upper_mask.all() + no_mask = (1. - lower_mask) * (1. - upper_mask.to(upper.dtype)) + if lb_lower_d is not None: + lb_lower_d = ( + torch.clamp(lb_lower_d, min=leaky_alpha, max=1.) * no_mask + + lower_mask) + if leaky_alpha > 0: + lb_lower_d += upper_mask * leaky_alpha + if ub_lower_d is not None: + ub_lower_d = ( + torch.clamp(ub_lower_d, min=leaky_alpha, max=1.) * no_mask + + lower_mask) + if leaky_alpha > 0: + ub_lower_d += upper_mask * leaky_alpha + return lb_lower_d, ub_lower_d, zero_coeffs + + def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx): + # Usage of output constraints requires access to bounds of the previous iteration + # (see _clear_and_set_new) + if x is not None: + apply_output_constraints_to = self.options['optimize_bound_args']['apply_output_constraints_to'] + if hasattr(x, "lower"): + lower = x.lower + else: + assert start_node.are_output_constraints_activated_for_layer(apply_output_constraints_to) + lower = x.previous_iteration_lower + if hasattr(x, "upper"): + upper = x.upper + else: + assert start_node.are_output_constraints_activated_for_layer(apply_output_constraints_to) + upper = x.previous_iteration_upper + else: + lower = self.lower + upper = self.upper + + # Upper bound slope and intercept according to CROWN relaxation. + upper_d, upper_b = self._relu_upper_bound(lower, upper, self.leaky_alpha) + + flag_expand = False + ub_lower_d = lb_lower_d = None + lower_b = None # ReLU does not have lower bound intercept (=0). + alpha_lookup_idx = None # For sparse-spec alpha. + if self.opt_stage in ['opt', 'reuse']: + # Alpha-CROWN. + lower_d = None + selected_alpha, alpha_lookup_idx = self.select_alpha_by_idx(last_lA, last_uA, + unstable_idx, start_node, alpha_lookup_idx) + # The first dimension is lower/upper intermediate bound. + if last_lA is not None: + lb_lower_d = selected_alpha[0] + if last_uA is not None: + ub_lower_d = selected_alpha[1] + + if self.alpha_indices is not None: + # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only. + # Recover to full alpha first. + sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape + full_alpha_shape = sparse_alpha_shape[:-1] + self.shape + if lb_lower_d is not None: + lb_lower_d = self.reconstruct_full_alpha( + lb_lower_d, full_alpha_shape, self.alpha_indices) + if ub_lower_d is not None: + ub_lower_d = self.reconstruct_full_alpha( + ub_lower_d, full_alpha_shape, self.alpha_indices) + + lb_lower_d, ub_lower_d, zero_coeffs = self._relu_mask_alpha(lower, upper, lb_lower_d, ub_lower_d) + self.zero_backward_coeffs_l = self.zero_backward_coeffs_u = zero_coeffs + flag_expand = True # we already have the spec dimension. + else: + # FIXME: the shape can be incorrect if unstable_idx is not None. + # This will cause problem if some ReLU layers are optimized, some are not. + lower_d = self._relu_lower_bound_init(upper_d) + + # Upper bound always needs an extra specification dimension, since they only depend on lb and ub. + upper_d = upper_d.unsqueeze(0) + upper_b = upper_b.unsqueeze(0) + if not flag_expand: + if self.opt_stage in ['opt', 'reuse']: + # We have different slopes for lower and upper bounds propagation. + lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None + ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None + else: + lower_d = lower_d.unsqueeze(0) + return (upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, + None, None, alpha_lookup_idx) + + def interval_propagate(self, *v): + h_L, h_U = v[0][0], v[0][1] + return self.forward(h_L), self.forward(h_U) + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + if self.leaky_alpha > 0: + raise NotImplementedError + + # e.g., last layer input gurobi vars (8,16,16) + gvars_array = np.array(v[0]) + this_layer_shape = gvars_array.shape + assert gvars_array.shape == self.output_shape[1:] + + pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1) + pre_ubs = self.inputs[0].upper.cpu().detach().numpy().reshape(-1) + + new_layer_gurobi_vars = [] + relu_integer_vars = [] + new_relu_layer_constrs = [] + # predefined zero variable shared in the whole solver model + zero_var = model.getVarByName("zero") + + for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)): + pre_ub = pre_ubs[neuron_idx] + pre_lb = pre_lbs[neuron_idx] + + if pre_lb >= 0: + # ReLU is always passing + var = pre_var + elif pre_ub <= 0: + var = zero_var + else: + ub = pre_ub + + var = model.addVar(ub=ub, lb=pre_lb, + obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'ReLU{self.name}_{neuron_idx}') + + if model_type == "mip" or model_type == "lp_integer": + # binary indicator + if model_type == "mip": + a = model.addVar(vtype=grb.GRB.BINARY, name=f'aReLU{self.name}_{neuron_idx}') + elif model_type == "lp_integer": + a = model.addVar(ub=1, lb=0, vtype=grb.GRB.CONTINUOUS, name=f'aReLU{self.name}_{neuron_idx}') + relu_integer_vars.append(a) + + new_relu_layer_constrs.append( + model.addConstr(pre_var - pre_lb * (1 - a) >= var, + name=f'ReLU{self.name}_{neuron_idx}_a_0')) + new_relu_layer_constrs.append( + model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1')) + new_relu_layer_constrs.append( + model.addConstr(pre_ub * a >= var, name=f'ReLU{self.name}_{neuron_idx}_a_2')) + new_relu_layer_constrs.append( + model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_3')) + + elif model_type == "lp": + new_relu_layer_constrs.append( + model.addConstr(var >= 0, name=f'ReLU{self.name}_{neuron_idx}_a_0')) + new_relu_layer_constrs.append( + model.addConstr(var >= pre_var, name=f'ReLU{self.name}_{neuron_idx}_a_1')) + new_relu_layer_constrs.append(model.addConstr( + pre_ub * pre_var - (pre_ub - pre_lb) * var >= pre_ub * pre_lb, + name=f'ReLU{self.name}_{neuron_idx}_a_2')) + + else: + print(f"gurobi model type {model_type} not supported!") + + new_layer_gurobi_vars.append(var) + + new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist() + if model_type in ["mip", "lp_integer"]: + self.integer_vars = relu_integer_vars + self.solver_vars = new_layer_gurobi_vars + self.solver_constrs = new_relu_layer_constrs + model.update() + + def build_gradient_node(self, grad_upstream): + if self.leaky_alpha > 0: + raise NotImplementedError + node_grad = ReLUGrad() + grad_input = (grad_upstream, self.inputs[0].forward_value) + # An extra node is needed to consider the state of ReLU activation + grad_extra_nodes = [self.inputs[0]] + return node_grad, grad_input, grad_extra_nodes + + def get_split_mask(self, lower, upper, input_index): + assert input_index == 0 + return torch.logical_and(lower < 0, upper > 0) + + +class BoundLeakyRelu(BoundRelu): + pass + + +class BoundSign(BoundActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.splittable = True + + def forward(self, x): + return torch.sign(x) + + def bound_relax(self, x, init=False): + if init: + self.init_linear_relaxation(x) + mask_0 = torch.logical_and(x.lower == 0, x.upper == 0) + mask_pos_0 = torch.logical_and(x.lower == 0, x.upper > 0) + mask_neg_0 = torch.logical_and(x.lower < 0, x.upper == 0) + mask_pos = x.lower > 0 + mask_neg = x.upper < 0 + mask_both = torch.logical_not(torch.logical_or(torch.logical_or( + mask_0, torch.logical_or(mask_pos, mask_pos_0)), + torch.logical_or(mask_neg, mask_neg_0))) + self.add_linear_relaxation(mask=mask_0, type='lower', + k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=0) + self.add_linear_relaxation(mask=mask_0, type='upper', + k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=0) + + self.add_linear_relaxation(mask=mask_pos_0, type='lower', + k=1/x.upper.clamp(min=1e-8), x0=torch.zeros_like(x.upper), y0=0) + self.add_linear_relaxation(mask=torch.logical_or(mask_pos_0, mask_pos), type='upper', + k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=1) + + self.add_linear_relaxation(mask=torch.logical_or(mask_neg_0, mask_neg), type='lower', + k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=-1) + self.add_linear_relaxation(mask=mask_neg_0, type='upper', + k=-1/x.lower.clamp(max=-1e-8), x0=torch.zeros_like(x.upper), y0=0) + + self.add_linear_relaxation(mask=mask_pos, type='lower', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=1) + self.add_linear_relaxation(mask=mask_neg, type='upper', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=-1) + self.add_linear_relaxation(mask=mask_both, type='lower', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=-1) + self.add_linear_relaxation(mask=mask_both, type='upper', k=0, x0=torch.zeros_like(x.upper, requires_grad=True), y0=1) + + +class SignMergeFunction_loose(torch.autograd.Function): + # Modified SignMerge operator. + # Change its backward function so that the "gradient" can be used for pgd attack + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + output = torch.sign(torch.sign(input) + 1e-1) + return output + + @staticmethod + def backward(ctx, grad_output): + eps = 5 # should be carefully chosen + input, = ctx.saved_tensors + grad_input = grad_output.clone() + grad_input[abs(input) >= eps] = 0 + grad_input /= eps + return grad_input + +class SignMergeFunction_tight(torch.autograd.Function): + # Modified SignMerge operator. + # Change its backward function so that the "gradient" can be used for pgd attack + @staticmethod + def forward(ctx, input): + ctx.save_for_backward(input) + output = torch.sign(torch.sign(input) + 1e-1) + return output + + @staticmethod + def backward(ctx, grad_output): + eps = 0.1 # should be carefully chosen + input, = ctx.saved_tensors + grad_input = grad_output.clone() + grad_input[abs(input) >= eps] = 0 + grad_input /= eps + return grad_input + + +class BoundSignMerge(BoundTwoPieceLinear): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.alpha_size = 4 + self.loose_function = SignMergeFunction_loose + self.tight_function = SignMergeFunction_tight + self.signmergefunction = self.tight_function # default + + def get_unstable_idx(self): + self.alpha_indices = torch.logical_and( + self.inputs[0].lower < 0, self.inputs[0].upper >= 0).any(dim=0).nonzero(as_tuple=True) + + def forward(self, x): + self.shape = x.shape[1:] + return self.signmergefunction.apply(x) + + def _mask_alpha(self, lower, upper, lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d): + lower_mask = (lower >= 0.).requires_grad_(False).to(lower.dtype) + upper_mask = (upper < 0.).requires_grad_(False).to(upper.dtype) + no_mask = 1. - (lower_mask + upper_mask) + if lb_lower_d is not None: + lb_lower_d = torch.min(lb_lower_d, 2/upper.clamp(min=1e-8)) + lb_lower_d = torch.clamp(lb_lower_d, min=0) * no_mask + lb_upper_d = torch.min(lb_upper_d, -2/lower.clamp(max=-1e-8)) + lb_upper_d = torch.clamp(lb_upper_d, min=0) * no_mask + if ub_lower_d is not None: + ub_lower_d = torch.min(ub_lower_d, 2/upper.clamp(min=1e-8)) + ub_lower_d = torch.clamp(ub_lower_d, min=0) * no_mask + ub_upper_d = torch.min(ub_upper_d, -2/lower.clamp(max=-1e-8)) + ub_upper_d = torch.clamp(ub_upper_d, min=0) * no_mask + return lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d + + def _backward_relaxation(self, last_lA, last_uA, x, start_node, unstable_idx): + if x is not None: + lower, upper = x.lower, x.upper + else: + lower, upper = self.lower, self.upper + + flag_expand = False + ub_lower_d = lb_lower_d = lb_upper_d = ub_upper_d = None + alpha_lookup_idx = None # For sparse-spec alpha. + if self.opt_stage in ['opt', 'reuse']: + # Alpha-CROWN. + upper_d = lower_d = None + selected_alpha, alpha_lookup_idx = self.select_alpha_by_idx(last_lA, last_uA, + unstable_idx, start_node, alpha_lookup_idx) + # The first dimension is lower/upper intermediate bound. + if last_lA is not None: + lb_lower_d = selected_alpha[0] + lb_upper_d = selected_alpha[2] + if last_uA is not None: + ub_lower_d = selected_alpha[1] + ub_upper_d = selected_alpha[3] + + if self.alpha_indices is not None: + # Sparse alpha on the hwc dimension. We store slopes for unstable neurons in this layer only. + # Recover to full alpha first. + sparse_alpha_shape = lb_lower_d.shape if lb_lower_d is not None else ub_lower_d.shape + full_alpha_shape = sparse_alpha_shape[:-1] + self.shape + if lb_lower_d is not None: + lb_lower_d = self.reconstruct_full_alpha( + lb_lower_d, full_alpha_shape, self.alpha_indices) + lb_upper_d = self.reconstruct_full_alpha( + lb_upper_d, full_alpha_shape, self.alpha_indices) + if ub_lower_d is not None: + ub_lower_d = self.reconstruct_full_alpha( + ub_lower_d, full_alpha_shape, self.alpha_indices) + ub_upper_d = self.reconstruct_full_alpha( + ub_upper_d, full_alpha_shape, self.alpha_indices) + + # condition only on the masked part + if self.inputs[0].alpha_beta_update_mask is not None: + update_mask = self.inputs[0].alpha_beta_update_mask + if lb_lower_d is not None: + lb_lower_d_new = lb_lower_d[:, update_mask] + lb_upper_d_new = lb_upper_d[:, update_mask] + else: + lb_lower_d_new = lb_upper_d_new = None + if ub_lower_d is not None: + ub_lower_d_new = ub_lower_d[:, update_mask] + ub_upper_d_new = ub_upper_d[:, update_mask] + else: + ub_lower_d_new = ub_upper_d_new = None + lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d = self._mask_alpha(lower, upper, + lb_lower_d_new, ub_lower_d_new, lb_upper_d_new, ub_upper_d_new) + else: + lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d = self._mask_alpha(lower, upper, + lb_lower_d, ub_lower_d, lb_upper_d, ub_upper_d) + flag_expand = True # we already have the spec dimension. + else: + lower_d = torch.zeros_like(upper, requires_grad=True) + upper_d = torch.zeros_like(upper, requires_grad=True) + + mask_pos = (x.lower >= 0.).requires_grad_(False).to(x.lower.dtype) + mask_neg = (x.upper < 0.).requires_grad_(False).to(x.upper.dtype) + lower_b = (-1 * (1 - mask_pos) + mask_pos).unsqueeze(0) + upper_b = (-1 * mask_neg + (1 - mask_neg)).unsqueeze(0) + + # Upper bound always needs an extra specification dimension, since they only depend on lb and ub. + if not flag_expand: + if self.opt_stage in ['opt', 'reuse']: + # We have different slopes for lower and upper bounds propagation. + lb_lower_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None + ub_lower_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None + lb_upper_d = lb_lower_d.unsqueeze(0) if last_lA is not None else None + ub_upper_d = ub_lower_d.unsqueeze(0) if last_uA is not None else None + else: + lower_d = lower_d.unsqueeze(0) + upper_d = upper_d.unsqueeze(0) + return (upper_d, upper_b, lower_d, lower_b, lb_lower_d, ub_lower_d, + lb_upper_d, ub_upper_d, alpha_lookup_idx) + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + + # e.g., last layer input gurobi vars (8,16,16) + gvars_array = np.array(v[0]) + this_layer_shape = gvars_array.shape + assert gvars_array.shape == self.output_shape[1:] + + pre_lbs = self.inputs[0].lower.cpu().detach().numpy().reshape(-1) + pre_ubs = self.inputs[0].upper.cpu().detach().numpy().reshape(-1) + + new_layer_gurobi_vars = [] + integer_vars = [] + layer_constrs = [] + # predefined zero variable shared in the whole solver model + one_var = model.getVarByName("one") + neg_one_var = model.getVarByName("neg_one") + + for neuron_idx, pre_var in enumerate(gvars_array.reshape(-1)): + pre_ub = pre_ubs[neuron_idx] + pre_lb = pre_lbs[neuron_idx] + + if pre_lb >= 0: + var = one_var + elif pre_ub < 0: + var = neg_one_var + else: + ub = pre_ub + + var = model.addVar(ub=ub, lb=pre_lb, + obj=0, + vtype=grb.GRB.CONTINUOUS, + name=f'Sign{self.name}_{neuron_idx}') + + a = model.addVar(vtype=grb.GRB.BINARY, name=f'aSign{self.name}_{neuron_idx}') + integer_vars.append(a) + + layer_constrs.append( + model.addConstr(pre_lb * a <= pre_var, name=f'Sign{self.name}_{neuron_idx}_a_0')) + layer_constrs.append( + model.addConstr(pre_ub * (1 - a) >= pre_var, name=f'Sign{self.name}_{neuron_idx}_a_1')) + layer_constrs.append( + model.addConstr(var == 1 - 2*a, name=f'Sign{self.name}_{neuron_idx}_a_2')) + + new_layer_gurobi_vars.append(var) + + new_layer_gurobi_vars = np.array(new_layer_gurobi_vars).reshape(this_layer_shape).tolist() + if model_type in ["mip", "lp_integer"]: + self.integer_vars = integer_vars + self.solver_vars = new_layer_gurobi_vars + self.solver_constrs = layer_constrs + model.update() \ No newline at end of file diff --git a/auto_LiRPA/operators/resize.py b/auto_LiRPA/operators/resize.py new file mode 100644 index 0000000..e809517 --- /dev/null +++ b/auto_LiRPA/operators/resize.py @@ -0,0 +1,303 @@ +""" Resize operator """ +import itertools + +import torch + +from .base import * +import numpy as np +from .solver_utils import grb +from ..patches import unify_shape, create_valid_mask, is_shape_used +from .gradient_modules import Conv2dGrad + + +class BoundResize(Bound): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + # only support nearest mode for now + assert attr["mode"] == "nearest" + self.mode = attr["mode"] + self.scale_factor = None + + def forward(self, x, size=None, scale_factor=None): + # currently, forwarding size is not supported. + assert isinstance(size, torch.Tensor) and len(size.tolist()) == 0 + # currently, only support enlarge tensor size by an integer factor. + assert len(scale_factor.tolist()) == 4 and np.array([tmp.is_integer() and tmp > 0 for tmp in scale_factor.tolist()]).all() + assert (scale_factor[0:2].to(torch.long) == 1).all(), 'only support resize on the H and W dim' + self.scale_factor = tuple([int(tmp) for tmp in scale_factor][2:]) + if x.ndim == 4: + final = F.interpolate( + x, None, self.scale_factor, mode=self.mode) + else: + raise NotImplementedError( + "Interpolation in 3D or interpolation with parameter size has not been implmented.") + return final + + def interval_propagate(self, *v): + l, u = zip(*v) + return Interval.make_interval(self.forward(*l), self.forward(*u), v[0]) + + def bound_forward(self, dim_in, *inp): + x = inp[0] + lw, lb, uw, ub = x.lw, x.lb, x.uw, x.ub + new_lw, new_lb, new_uw, new_ub = \ + torch.nn.functional.upsample(lw, scale_factor=([1] * (lw.ndim - 4)) + list(self.scale_factor), mode=self.mode), \ + torch.nn.functional.upsample(lb, scale_factor=([1] * (lb.ndim - 4)) + list(self.scale_factor), mode=self.mode), \ + torch.nn.functional.upsample(uw, scale_factor=([1] * (uw.ndim - 4)) + list(self.scale_factor), mode=self.mode), \ + torch.nn.functional.upsample(ub, scale_factor=([1] * (ub.ndim - 4)) + list(self.scale_factor), mode=self.mode) + return LinearBound( + lw = new_lw, + lb = new_lb, + uw = new_uw, + ub = new_ub) + + def bound_backward(self, last_lA, last_uA, *x, **kwargs): + + def _bound_oneside(last_A): + if last_A is None: + return None + assert type(last_A) is Patches or last_A.ndim == 5 + # in case the kernel size cannot be divided by scale_factor, we round up the shape + split_shape = tuple((torch.tensor( + last_A.shape)[-2:] / torch.tensor(self.scale_factor)).ceil().to(torch.long).tolist()) + new_shape = last_A.shape[:-2] + split_shape + if not type(last_A) is Patches: + # classical mode is simple to handle by + # sum the grid elements by using avg_pool2d with divisor_override=1 + return torch.nn.functional.avg_pool2d( + last_A.reshape(-1, *last_A.shape[-2:]), kernel_size=self.scale_factor, stride=self.scale_factor, + divisor_override=1).reshape(new_shape) + else: + # for patches mode + assert type(last_A) is Patches + assert self.scale_factor[0] == self.scale_factor[1] + if self.scale_factor[0] == 1: + # identity upsampling + return last_A + if isinstance(last_A.padding, int) and last_A.padding % self.scale_factor[0] == 0 and last_A.stride % self.scale_factor[0] == 0 and last_A.inserted_zeros == 0: + # an easy case where patch sliding windows coincides with the nearest sampling scaling windows + # in this case, we divide each patch to size of scale_factor sub-matrices, + # and sum up each sub-matrices respectively + # print(last_A.shape) + padding = last_A.shape[-1] % self.scale_factor[-1] + new_patches = torch.nn.functional.pad(last_A.patches, (0, padding, 0, padding)) + new_patches = torch.nn.functional.avg_pool2d( + new_patches.reshape(-1, *new_patches.shape[-2:]), kernel_size=self.scale_factor, + stride=self.scale_factor, divisor_override=1).reshape(new_shape) + return last_A.create_similar(patches=new_patches, + stride=last_A.stride//self.scale_factor[0], + padding=last_A.padding//self.scale_factor[0], + ) + else: + """ + The following part is created and mainly maintained by Linyi + Time complexity = O(A.numel * scale_factor + outH * kerH + outW * kerW + A.numel * kerH * kerW) + With Python loop complexity = O(outH + outW + kerH * kerW * scale_factor^2) + """ + # preparation: unify shape + if last_A.padding: + padding = unify_shape(last_A.padding) + else: + padding = (0,0,0,0) + # padding = (left, right, top, bottom) + if last_A.output_padding: + output_padding = unify_shape(last_A.output_padding) + else: + output_padding = (0,0,0,0) + # output_padding = (left, right, top, bottom) + + """ + Step 0: filter out valid entries that maps to real cells of input + Like with inserted zeros = 2, [x 0 0 x 0 0 x]. Only "x" cells are kept + Borrowed from one_d generation from Conv patches + """ + one_d_unfolded_r = create_valid_mask(self.output_shape, + last_A.patches.device, + last_A.patches.dtype, + last_A.patches.shape[-2:], + last_A.stride, + last_A.inserted_zeros, + last_A.padding, + last_A.output_padding, + last_A.unstable_idx) + patches = last_A.patches * one_d_unfolded_r + + """ + Step 1: compute the coordinate mapping from patch coordinates to input coordinates + Time complexity: O(outH + outW) + note: last_A shape is [outC, batch, outH, outW, inC, kerH, kerW] + We create H_idx_map and W_idx_map of shape [outH] and [outW] respectively, + recording the start idx of row/column for patches at position [.,.,.,.,.,i,j] + in H_idx_map[i] and W_idx_map[j] + """ + ker_size_h, ker_size_w = last_A.shape[-2], last_A.shape[-1] + if last_A.unstable_idx is None: + # we can get the real output H and W from shape[2] and shape [3] + out_h, out_w = last_A.shape[2], last_A.shape[3] + else: + # it seems to be stored in output_shape + out_h, out_w = last_A.output_shape[-2], last_A.output_shape[-1] + h_idx_map = torch.arange(0, out_h) * last_A.stride - padding[-2] + output_padding[-2] * last_A.stride + h_idx_map = h_idx_map.to(last_A.device) + w_idx_map = torch.arange(0, out_w) * last_A.stride - padding[-4] + output_padding[-4] * last_A.stride + w_idx_map = w_idx_map.to(last_A.device) + + r""" + Step 2: compute the compressed patches + Time complexity: O(outH * kerH + outW * kerW + A.numel * kerH * kerW) + Upsampling needs to sum up A cells in scale_factor * scale_factor sub-blocks + Example: when scale factor is 2 + [ a b c d + e f g h ---\ [ a+b+e+f c+d+g+h + i j k l ---/ i+j+m+n k+l+o+p] + m n o p] + In patches mode, we need to sum up cells in each patch accordingly. + The summing mechanism could change at different locations. + For each spatial dimension, we create a binary sum_mask tensor [outH, ker_size_h, new_ker_size_h] + to select the cells to sum up + Example: + For [a b c d] -> [a+b c+d], with 3x3 patch covering [0..2] and [2..4]. + The first patch needs to sum to [a+b c]; the second patch needs to sum to [b c+d] + So we have sum_mask + [ for patch 1: [[1, 1, 0], (first entry sums up index 0 and 1) + [0, 0, 1]]^T, (second entry sums up index 2) + for patch 2: [[1, 0, 0], (first entry sums up index 0) + [0, 1, 1]]^T (second entry sums up index 1 and 2) + ] + With the mask, we can now compute the new patches with einsum: + [outC, batch, outH, outW, inC, kerH, kerW] * [outH, kerH, new_kerH] -> [outC, batch, outH, outW, inC, new_kerH, kerW] + """ + tot_scale_fac = ((last_A.inserted_zeros + 1) * self.scale_factor[0], (last_A.inserted_zeros + 1) * self.scale_factor[1]) + new_ker_size_h, new_ker_size_w = \ + (tot_scale_fac[0] + ker_size_h - 2) // tot_scale_fac[0] + 1, \ + (tot_scale_fac[1] + ker_size_w - 2) // tot_scale_fac[1] + 1 + + min_h_idx, max_h_idx = h_idx_map[0], h_idx_map[-1] + ker_size_h + shrank_h_idx = (torch.arange(min_h_idx, max_h_idx) + last_A.inserted_zeros).div(tot_scale_fac[0], rounding_mode='floor') + if last_A.unstable_idx is None: + # with nonsparse index, create full-sized sum musk for rows + ker_h_indexer = torch.arange(0, ker_size_h).to(last_A.device) + sum_mask_h = torch.zeros(last_A.shape[2], ker_size_h, new_ker_size_h).to(last_A.device) + for i in range(last_A.shape[2]): + sum_mask_h[i, ker_h_indexer, \ + shrank_h_idx[h_idx_map[i] - min_h_idx: h_idx_map[i] - min_h_idx + ker_size_h] - shrank_h_idx[h_idx_map[i] - min_h_idx]] = 1 + # set zero to those in padding area + padding_place_mask = (ker_h_indexer + h_idx_map[i] < 0) + sum_mask_h[i, padding_place_mask] = 0 + else: + # with sparse index, create sparse sum musk + sum_mask_h = torch.zeros(last_A.shape[0], ker_size_h, new_ker_size_h).to(last_A.device) + + row_nos = last_A.unstable_idx[1] + unstable_loc_indexer = torch.arange(0, row_nos.shape[0]).to(last_A.device) + + for k in range(ker_size_h): + place_in_new_ker = shrank_h_idx[h_idx_map[row_nos] - min_h_idx + k] - shrank_h_idx[h_idx_map[row_nos] - min_h_idx] + sum_mask_h[unstable_loc_indexer, k, place_in_new_ker] = 1 + # set zero to those in padding area + padding_place_mask = (h_idx_map[row_nos] + k < 0) + sum_mask_h[padding_place_mask, k] = 0 + + min_w_idx, max_w_idx = w_idx_map[0], w_idx_map[-1] + ker_size_w + shrank_w_idx = (torch.arange(min_w_idx, max_w_idx) + last_A.inserted_zeros).div(tot_scale_fac[1], rounding_mode='floor') + if last_A.unstable_idx is None: + # with nonsparse index, create full-sized sum musk for columns + ker_w_indexer = torch.arange(0, ker_size_w).to(last_A.device) + sum_mask_w = torch.zeros(last_A.shape[3], ker_size_w, new_ker_size_w).to(last_A.device) + for i in range(last_A.shape[3]): + sum_mask_w[i, ker_w_indexer, \ + shrank_w_idx[w_idx_map[i] - min_w_idx: w_idx_map[i] - min_w_idx + ker_size_w] - shrank_w_idx[w_idx_map[i] - min_w_idx]] = 1 + # set zero to those in padding area + padding_place_mask = (ker_w_indexer + w_idx_map[i] < 0) + sum_mask_w[i, padding_place_mask] = 0 + else: + # with sparse index, create sparse sum musk + sum_mask_w = torch.zeros(last_A.shape[0], ker_size_w, new_ker_size_w).to(last_A.device) + + col_nos = last_A.unstable_idx[2] + unstable_loc_indexer = torch.arange(0, col_nos.shape[0]).to(last_A.device) + + for k in range(ker_size_w): + place_in_new_ker = shrank_w_idx[w_idx_map[col_nos] - min_w_idx + k] - shrank_w_idx[w_idx_map[col_nos] - min_w_idx] + sum_mask_w[unstable_loc_indexer, k, place_in_new_ker] = 1 + # set zero to those in padding area + padding_place_mask = (w_idx_map[col_nos] + k < 0) + sum_mask_w[padding_place_mask, k] = 0 + + if last_A.unstable_idx is None: + # nonsparse aggregation + new_patches = torch.einsum("ObhwIij,hix,wjy->ObhwIxy", patches, sum_mask_h, sum_mask_w) + else: + # sparse aggregation + new_patches = torch.einsum("NbIij,Nix,Njy->NbIxy", patches, sum_mask_h, sum_mask_w) + + """ + Step 3: broadcasting the new_patches by repeating elements, + since later we would need to apply insert_zeros + For example, scale_factor = 3, repeat patch [a,b] to [a,a,a,b,b,b] + Time complexity: O(A.numel * scale_factor) + """ + ext_new_ker_size_h, ext_new_ker_size_w = \ + new_ker_size_h * tot_scale_fac[0], new_ker_size_w * tot_scale_fac[1] + ext_new_patches = torch.zeros(list(new_patches.shape[:-2]) + + [ext_new_ker_size_h, ext_new_ker_size_w], device=new_patches.device) + for i in range(ext_new_ker_size_h): + for j in range(ext_new_ker_size_w): + ext_new_patches[..., i, j] = new_patches[..., i // tot_scale_fac[0], j // tot_scale_fac[1]] + + """ + Step 4: compute new padding, stride, shape, insert_zeros, and output_padding + """ + # stride should be the same after upsampling, stride is an integer + # new_stride = last_A.stride + # padding can change much, the beginning should extend by (scale - 1) entries, + # the ending should extend by (ext_new_ker_size - ker_size) entries + # padding = (left, right, top, bottom) + new_padding = (padding[0] + (self.scale_factor[1] - 1) * (last_A.inserted_zeros + 1), + padding[1] + ext_new_ker_size_w - ker_size_w, + padding[2] + (self.scale_factor[0] - 1) * (last_A.inserted_zeros + 1), + padding[3] + ext_new_ker_size_h - ker_size_h) + if new_padding[0] == new_padding[1] and new_padding[1] == new_padding[2] and new_padding[2] == new_padding[3]: + # simplify to an int + new_padding = new_padding[0] + # only support uniform scaling on H and W now, i.e., self.scale_factor[0] == self.scale_factor[1] + inserted_zeros = tot_scale_fac[0] - 1 + # output padding seems not to change + # new_output_padding = last_A.output_padding + + """ + Package and create + """ + # sparse tensor doesn't support einsum which is necessary for subsequent computes, so deprecated + # if inserted_zeros >= 3: + # # mask unused cells + # input_shape = list(self.output_shape) + # input_shape[-2], input_shape[-1] = input_shape[-2] // self.scale_factor[-2], \ + # input_shape[-1] // self.scale_factor[-1] + # one_unfolded = create_valid_mask(input_shape, ext_new_patches.device, + # ext_new_patches.dtype, ext_new_patches.shape[-2:], + # last_A.stride, inserted_zeros, new_padding, + # last_A.output_padding, + # last_A.unstable_idx if last_A.unstable_idx else None) + # ext_new_patches = (ext_new_patches * one_unfolded).to_sparse() + + # print the shape change after upsampling, if needed + # print(f'After upsampling, ' + # f'{last_A.patches.shape} (pad={padding}, iz={last_A.inserted_zeros}, s={last_A.stride}) -> ' + # f'{ext_new_patches.shape} (pad={new_padding}, iz={inserted_zeros}, s={last_A.stride})') + ret_patches_A = last_A.create_similar(patches=ext_new_patches, + padding=new_padding, + inserted_zeros=inserted_zeros) + if self.input_shape[-2] < ret_patches_A.shape[-2] and self.input_shape[-1] < ret_patches_A.shape[-2] \ + and not is_shape_used(ret_patches_A.output_padding): + # using matrix mode could be more memory efficient + ret_matrix_A = ret_patches_A.to_matrix(self.input_shape) + # print(f'After upsampling, to_matrix: {ret_matrix_A.shape}') + ret_matrix_A = ret_matrix_A.transpose(0, 1) + return ret_matrix_A + else: + return ret_patches_A + + last_lA = _bound_oneside(last_lA) + last_uA = _bound_oneside(last_uA) + return [(last_lA, last_uA), (None, None), (None, None)], 0, 0 diff --git a/auto_LiRPA/operators/rnn.py b/auto_LiRPA/operators/rnn.py index dc634ae..a30dbbd 100644 --- a/auto_LiRPA/operators/rnn.py +++ b/auto_LiRPA/operators/rnn.py @@ -3,7 +3,7 @@ class BoundRNN(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.complex = True self.output_index = output_index diff --git a/auto_LiRPA/operators/shape.py b/auto_LiRPA/operators/shape.py index 1473baa..7650904 100644 --- a/auto_LiRPA/operators/shape.py +++ b/auto_LiRPA/operators/shape.py @@ -3,10 +3,11 @@ from ..patches import Patches, patches_to_matrix from .linear import BoundLinear from .gradient_modules import ReshapeGrad +from .constant import BoundConstant class BoundReshape(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) # It can be set to `view`, so that `view` instead of `reshape` will be used. self.option = options.get('reshape', 'reshape') @@ -22,7 +23,7 @@ def forward(self, x, shape): else: return x.reshape(shape) - def bound_backward(self, last_lA, last_uA, x, shape): + def bound_backward(self, last_lA, last_uA, x, shape, **kwargs): def _bound_oneside(A): if A is None: return None @@ -69,6 +70,11 @@ def bound_forward(self, dim_in, x, shape): ub = x.ub.reshape(batch_size, *self.shape[1:]) return LinearBound(lw, lb, uw, ub) + def bound_dynamic_forward(self, x, shape, max_dim=None, offset=0): + w = x.lw.reshape(x.lw.shape[0], x.lw.shape[1], *self.shape[1:]) + b = x.lb.reshape(x.lb.shape[0], *self.shape[1:]) + return LinearBound(w, b, w, b, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim) + def interval_propagate(self, *v): return Interval.make_interval( self.forward(v[0][0], v[1][0]), @@ -89,17 +95,31 @@ def build_gradient_node(self, grad_upstream): class BoundUnsqueeze(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) - self.axes = attr['axes'] - assert (len(self.axes) == 1) - self.axes = self.axes[0] + if 'axes' in attr: + self.axes = attr['axes'] + assert len(self.axes) == 1 + self.axes = self.axes[0] + else: + assert isinstance(inputs[1], BoundConstant) + assert inputs[1].value.numel() == 1 + self.axes = inputs[1].value[0] self.use_default_ibp = True - def forward(self, x): - return x.unsqueeze(self.axes) + def forward(self, *x): + return x[0].unsqueeze(self.axes) + + def bound_forward(self, dim_in, x): + assert self.axes > 0 + return LinearBound( + x.lw.unsqueeze(self.axes + 1), + x.lb.unsqueeze(self.axes), + x.uw.unsqueeze(self.axes + 1), + x.ub.unsqueeze(self.axies) + ) - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, *args, **kwargs): self.axes = self.make_axis_non_negative(self.axes, 'output') if self.axes == 0: # TODO: unsqueeze on batch dimension can be problematic. @@ -119,7 +139,7 @@ def bound_backward(self, last_lA, last_uA, x): uA = None return [(lA, uA)], 0, 0 - def bound_forward(self, dim_in, x): + def bound_forward(self, dim_in, x, *args): self.axes = self.make_axis_non_negative(self.axes, 'output') if len(self.input_shape) == 0: lw, lb = x.lw.unsqueeze(1), x.lb.unsqueeze(0) @@ -132,8 +152,9 @@ def bound_forward(self, dim_in, x): def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): self.solver_vars = self.forward(v[0]) + class BoundSqueeze(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.axes = attr['axes'] assert (len(self.axes) == 1) @@ -143,14 +164,26 @@ def __init__(self, attr, inputs, output_index, options): def forward(self, x): return x.squeeze(self.axes) - def bound_backward(self, last_lA, last_uA, x): - assert (self.axes != 0) + def bound_backward(self, last_lA, last_uA, x, **kwargs): + assert self.axes > 0 return [(last_lA.unsqueeze(self.axes + 1) if last_lA is not None else None, last_uA.unsqueeze(self.axes + 1) if last_uA is not None else None)], 0, 0 + def bound_forward(self, dim_in, x): + assert self.axes > 0 + return LinearBound( + x.lw.squeeze(self.axes + 1), + x.lb.squeeze(self.axes), + x.uw.squeeze(self.axes + 1), + x.ub.squeeze(self.axes) + ) + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(v[0]) + class BoundFlatten(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.use_default_ibp = True self.axis = attr['axis'] @@ -158,7 +191,7 @@ def __init__(self, attr, inputs, output_index, options): def forward(self, x): return torch.flatten(x, self.axis) - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): def _bound_oneside(A): if A is None: return None @@ -168,9 +201,7 @@ def _bound_oneside(A): def bound_dynamic_forward(self, x, max_dim=None, offset=0): w = torch.flatten(x.lw, self.axis + 1) b = torch.flatten(x.lb, self.axis) - x_L = torch.flatten(x.x_L, self.axis) - x_U = torch.flatten(x.x_U, self.axis) - return LinearBound(w, b, w, b, x_L=x_L, x_U=x_U, tot_dim=x.tot_dim) + return LinearBound(w, b, w, b, x_L=x.x_L, x_U=x.x_U, tot_dim=x.tot_dim) def bound_forward(self, dim_in, x): self.axis = self.make_axis_non_negative(self.axis) @@ -187,10 +218,14 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") self.solver_vars = np.array(v[0]).reshape(-1).tolist() model.update() + def build_gradient_node(self, grad_upstream): + node_grad = ReshapeGrad() + grad_input = (grad_upstream, self.inputs[0].forward_value) + return node_grad, grad_input, [] class BoundConcat(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.axis = attr['axis'] self.IBP_rets = None @@ -216,7 +251,7 @@ def interval_propagate(self, *v): eps = np.array(eps) # Supporting two cases: all inputs are Linf norm, or all inputs are L2 norm perturbed. # Some inputs can be constants without perturbations. - all_inf = all(map(lambda x: x is None or x == np.inf, norms)) + all_inf = all(map(lambda x: x is None or x == torch.inf, norms)) all_2 = all(map(lambda x: x is None or x == 2, norms)) h_L = [_v[0] for _v in v] @@ -235,7 +270,7 @@ def interval_propagate(self, *v): else: raise RuntimeError("BoundConcat does not support inputs with norm {}".format(norms)) - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): self.axis = self.make_axis_non_negative(self.axis, 'output') assert self.axis > 0 @@ -276,9 +311,9 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") BoundConcatFromSequence = BoundConcat class BoundShape(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) - self.use_default_ibp = True + self.never_perturbed = True @staticmethod def shape(x): @@ -318,11 +353,9 @@ def forward(self, x, indices): if indices == -1: indices = x.shape[self.axis] + indices return torch.index_select(x, dim=self.axis, index=indices).squeeze(self.axis) - elif self.axis == 0: - return torch.index_select( - x, dim=self.axis, index=indices.view(-1)).reshape( - *indices.shape, x.shape[-1]) - elif self.indices.ndim == 1: + elif indices.ndim == 1: + if self.axis == 0: + assert not self.perturbed # `index_select` requires `indices` to be a 1-D tensor return torch.index_select(x, dim=self.axis, index=indices) @@ -330,7 +363,7 @@ def forward(self, x, indices): f'data {x.shape}, indices {indices.shape}, ' f'axis {self.axis}') - def bound_backward(self, last_lA, last_uA, x, indices): + def bound_backward(self, last_lA, last_uA, *args, **kwargs): assert self.from_input def _expand_A_with_zeros(A, axis, idx, max_axis_size): @@ -365,7 +398,7 @@ def _bound_oneside(A): final_A = torch.zeros(*shape[:self.axis + 1], self.input_shape[self.axis], *shape[self.axis + 2:], device=A.device) idx = self.indices.view([*[1]*(self.axis+1), -1, *[1]*len(shape[self.axis + 2:])]) idx = idx.repeat([*A.shape[:self.axis+1], 1, *A.shape[self.axis+2:]]) - final_A.scatter_(dim=self.axis+1, index=idx, src=A) + final_A.scatter_add_(dim=self.axis+1, index=idx, src=A) return final_A elif isinstance(A, Patches): if self.indices.ndim == 0: @@ -384,7 +417,7 @@ def _bound_oneside(A): return [(_bound_oneside(last_lA), _bound_oneside(last_uA)), (None, None)], 0, 0 def bound_forward(self, dim_in, x, indices): - assert self.indices.ndim == 0 # TODO + assert self.indices.numel() == 1 and self.indices.ndim <= 1 # TODO if isinstance(x, torch.Size): lw = uw = torch.zeros(dim_in, device=self.device) @@ -393,10 +426,15 @@ def bound_forward(self, dim_in, x, indices): dim=self.axis, index=self.indices).squeeze(self.axis) else: axis = self.axis + 1 - lw = torch.index_select(x.lw, dim=self.axis + 1, index=self.indices).squeeze(axis) - uw = torch.index_select(x.uw, dim=self.axis + 1, index=self.indices).squeeze(axis) - lb = torch.index_select(x.lb, dim=self.axis, index=self.indices).squeeze(self.axis) - ub = torch.index_select(x.ub, dim=self.axis, index=self.indices).squeeze(self.axis) + lw = torch.index_select(x.lw, dim=self.axis + 1, index=self.indices) + uw = torch.index_select(x.uw, dim=self.axis + 1, index=self.indices) + lb = torch.index_select(x.lb, dim=self.axis, index=self.indices) + ub = torch.index_select(x.ub, dim=self.axis, index=self.indices) + if self.indices.ndim == 0: + lw = lw.squeeze(axis) + uw = uw.squeeze(axis) + lb = lb.squeeze(self.axis) + ub = ub.squeeze(self.axis) return LinearBound(lw, lb, uw, ub) def interval_propagate(self, *v): @@ -416,7 +454,7 @@ def forward(self, x, index): self.index = index return torch.gather(x, dim=self.axis, index=index) - def bound_backward(self, last_lA, last_uA, x, index): + def bound_backward(self, last_lA, last_uA, x, index, **kwargs): assert self.from_input dim = self._get_dim() @@ -455,7 +493,7 @@ def _get_dim(self): return dim class BoundTranspose(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.perm = attr['perm'] self.perm_inv_inc_one = [-1] * (len(self.perm) + 1) @@ -463,11 +501,12 @@ def __init__(self, attr, inputs, output_index, options): for i in range(len(self.perm)): self.perm_inv_inc_one[self.perm[i] + 1] = i + 1 self.use_default_ibp = True + self.ibp_intermediate = True def forward(self, x): return x.permute(*self.perm) - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): def _bound_oneside(last_A): if last_A is None: return None @@ -490,13 +529,24 @@ def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi") self.solver_vars = self.forward(*v) class BoundSlice(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.start = attr["starts"][0] if "starts" in attr else None self.end = attr["ends"][0] if "ends" in attr else None self.axes = attr["axes"][0] if "axes" in attr else None self.use_default_ibp = False + def __repr__(self): + attrs = {} + if (len(self.inputs) == 5 + and all(isinstance(item, BoundConstant) and item.value.numel() == 1 + for item in self.inputs[1:])): + attrs['start'] = self.inputs[1].value.item() + attrs['end'] = self.inputs[2].value.item() + attrs['axes'] = self.inputs[3].value.item() + attrs['step'] = self.inputs[4].value.item() + return super().__repr__(attrs) + def _fixup_params(self, shape, start, end, axes, steps): if start < 0: start += shape[axes] @@ -531,14 +581,15 @@ def interval_propagate(self, *v): def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): self.solver_vars = self.forward(*v) - def bound_backward(self, last_lA, last_uA, *x): + def bound_backward(self, last_lA, last_uA, *x, **kwargs): def _bound_oneside(A, start, end, axes, steps): if A is None: return None if isinstance(A, torch.Tensor): # Reuse the batch and spec dimension of A, and replace other shapes with input. A_shape = A.shape[:2] + self.input_shape[1:] - new_A = torch.zeros(size=A_shape, device=A.device, requires_grad=A.requires_grad) + new_A = torch.zeros(size=A_shape, device=A.device, + requires_grad=A.requires_grad) # Fill part of the new_A based on start, end, axes and steps. # Skip the spec dimension at the front (axes + 1). dim = axes if axes < 0 else axes + 1 @@ -550,7 +601,9 @@ def _bound_oneside(A, start, end, axes, steps): patches = A.patches # patches shape is [out_c, batch, out_h, out_w, in_c, patch_h, patch_w]. new_patches_shape = patches.shape[:4] + (self.input_shape[1], ) + patches.shape[-2:] - new_patches = torch.zeros(size=new_patches_shape, device=patches.device, requires_grad=patches.requires_grad) + new_patches = torch.zeros( + size=new_patches_shape, device=patches.device, + requires_grad=patches.requires_grad) indices = torch.arange(start, end, device=patches.device) new_patches = torch.index_copy(new_patches, dim=-3, index=indices, source=patches) # Only the in_c dimension is changed. @@ -593,9 +646,6 @@ def bound_forward(self, dim_in, *inputs): class BoundExpand(Bound): - def __init__(self, attr, inputs, output_index, options): - super().__init__(attr, inputs, output_index, options) - def forward(self, x, y): y = y.clone() assert y.ndim == 1 @@ -608,8 +658,9 @@ def forward(self, x, y): assert x.shape[i] == 1 or x.shape[i] == y[m - n + i] return x.expand(*list(y)) + class BoundSplit(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.axis = attr['axis'] self.split = attr['split'] @@ -620,7 +671,7 @@ def forward(self, x): self.axis = len(x.shape) - 1 return torch.split(x, self.split, dim=self.axis)[self.output_index] - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, **kwargs): assert self.axis > 0 pre = sum(self.split[:self.output_index]) suc = sum(self.split[(self.output_index + 1):]) @@ -649,3 +700,6 @@ def bound_forward(self, dim_in, x): lb = torch.split(x.lb, self.split, dim=self.axis)[self.output_index] ub = torch.split(x.ub, self.split, dim=self.axis)[self.output_index] return LinearBound(lw, lb, uw, ub) + + def build_solver(self, *v, model, C=None, model_type="mip", solver_pkg="gurobi"): + self.solver_vars = self.forward(v[0]) diff --git a/auto_LiRPA/operators/softmax.py b/auto_LiRPA/operators/softmax.py index cb91516..5804cbc 100644 --- a/auto_LiRPA/operators/softmax.py +++ b/auto_LiRPA/operators/softmax.py @@ -15,7 +15,7 @@ def forward(self, x): # The `option != 'complex'` case is not used in the auto_LiRPA main paper. class BoundSoftmax(Bound): - def __init__(self, attr, inputs, output_index, options): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): super().__init__(attr, inputs, output_index, options) self.axis = attr['axis'] self.option = options.get('softmax', 'complex') diff --git a/auto_LiRPA/operators/tanh.py b/auto_LiRPA/operators/tanh.py new file mode 100644 index 0000000..007a4e3 --- /dev/null +++ b/auto_LiRPA/operators/tanh.py @@ -0,0 +1,382 @@ +"""BoundTanh and similar ops.""" +import warnings +import torch +from .base import * +from .activation_base import BoundOptimizableActivation + + +def dtanh(x): + # to avoid bp error when cosh is too large + # cosh(25.0)**2 > 1e21 + mask = torch.lt(torch.abs(x), 25.0).to(x.dtype) + cosh = torch.cosh(mask * x + 1 - mask) + return mask * (1. / cosh.pow(2)) + +def dsigmoid(x): + return torch.sigmoid(x) * (1 - torch.sigmoid(x)) + +def darctan(x): + return (x.square() + 1.).reciprocal() + + +class BoundTanh(BoundOptimizableActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None, + activation=('tanh', torch.tanh, dtanh), precompute=True): + super().__init__(attr, inputs, output_index, options) + if options is None: + options = {} + self.splittable = True + self.ibp_intermediate = True + self.activation = activation + self.activation_name = activation[0] + self.activation_forward = activation[1] + self.activation_backward = activation[2] + if precompute: + self.precompute_relaxation(*activation) + # TODO make them configurable when implementing a general nonlinear activation. + # Neurons whose gap between pre-activation bounds is smaller than this + # threshold will be masked and don't need branching. + self.split_min_gap = 1e-2 #1e-4 + # Neurons whose pre-activation bounds don't overlap with this range + # are considered as stable (with values either 0 or 1) and don't need + # branching. + self.split_range = (-10, 10) + # The initialization will be adjusted if the pre-activation bounds are too loose. + self.loose_threshold = options.get('tanh', {}).get( + 'loose_threshold', None) + + def opt_init(self): + super().opt_init() + self.tp_both_lower_init = {} + self.tp_both_upper_init = {} + + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + shape = l.shape + # Alpha dimension is (8, output_shape, batch, *shape) for Tanh. + alpha = torch.empty(8, size_spec, *shape, device=l.device) + alpha.data[:4] = ((l + u) / 2) + alpha.data[4:6] = self.tp_both_lower_init[name_start] + alpha.data[6:8] = self.tp_both_upper_init[name_start] + return alpha + + @torch.no_grad() + def precompute_relaxation(self, name, func, dfunc, x_limit=500): + """ + This function precomputes the tangent lines that will be used as + lower/upper bounds for S-shapes functions. + """ + self.x_limit = x_limit + self.step_pre = 0.01 + self.num_points_pre = int(self.x_limit / self.step_pre) + max_iter = 100 + + logger.debug('Precomputing relaxation for %s (pre-activation limit: %f)', + name, x_limit) + + def check_lower(upper, d): + """Given two points upper, d (d <= upper), check if the slope at d will be less than f(upper) at upper.""" + k = dfunc(d) + # Return True if the slope is a lower bound. + return k * (upper - d) + func(d) <= func(upper) + + def check_upper(lower, d): + """Given two points lower, d (d >= lower), check if the slope at d will be greater than f(lower) at lower.""" + k = dfunc(d) + # Return True if the slope is a upper bound. + return k * (lower - d) + func(d) >= func(lower) + + # Given an upper bound point (>=0), find a line that is guaranteed to be a lower bound of this function. + upper = self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) + r = torch.zeros_like(upper) + # Initial guess, the tangent line is at -1. + l = -torch.ones_like(upper) + while True: + # Check if the tangent line at the guessed point is an lower bound at f(upper). + checked = check_lower(upper, l).int() + # If the initial guess is not smaller enough, then double it (-2, -4, etc). + l = checked * l + (1 - checked) * (l * 2) + if checked.sum() == l.numel(): + break + # Now we have starting point at l, its tangent line is guaranteed to be an lower bound at f(upper). + # We want to further tighten this bound by moving it closer to 0. + for _ in range(max_iter): + # Binary search. + m = (l + r) / 2 + checked = check_lower(upper, m).int() + l = checked * m + (1 - checked) * l + r = checked * r + (1 - checked) * m + # At upper, a line with slope l is guaranteed to lower bound the function. + self.d_lower = l.clone() + + # Do the same again: + # Given an lower bound point (<=0), find a line that is guaranteed to be an upper bound of this function. + lower = -self.step_pre * torch.arange(0, self.num_points_pre + 5, device=self.device) + l = torch.zeros_like(upper) + r = torch.ones_like(upper) + while True: + checked = check_upper(lower, r).int() + r = checked * r + (1 - checked) * (r * 2) + if checked.sum() == l.numel(): + break + for _ in range(max_iter): + m = (l + r) / 2 + checked = check_upper(lower, m).int() + l = (1 - checked) * m + checked * l + r = (1 - checked) * r + checked * m + self.d_upper = r.clone() + + logger.debug('Done') + + def forward(self, x): + return self.activation_forward(x) + + def bound_relax_impl(self, x, func, dfunc): + lower, upper = x.lower, x.upper + y_l, y_u = func(lower), func(upper) + # k_direct is the slope of the line directly connect (lower, func(lower)), (upper, func(upper)). + k_direct = k = (y_u - y_l) / (upper - lower).clamp(min=1e-8) + + # Fixed bounds that cannot be optimized. self.mask_neg are the masks for neurons with upper bound <= 0. + # Upper bound for the case of input lower bound <= 0, is always the direct line. + self.add_linear_relaxation( + mask=self.mask_neg, type='upper', k=k_direct, x0=lower, y0=y_l) + # Lower bound for the case of input upper bound >= 0, is always the direct line. + self.add_linear_relaxation( + mask=self.mask_pos, type='lower', k=k_direct, x0=lower, y0=y_l) + + # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed. + # Note that for neurons with also input lower bound >=0, they will be masked later. + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), + (upper / self.step_pre).to(torch.long).reshape(-1) + ) + 1 + if index.max() >= self.d_lower.numel(): + warnings.warn(f'Pre-activation bounds are too loose for {self}') + # Lookup the lower bound slope from the pre-computed table. + d_lower = torch.where( + (index < self.d_lower.numel()).view(lower.shape), + torch.index_select( + self.d_lower, 0, index.clamp(max=self.d_lower.numel() - 1) + ).view(lower.shape), + lower, + # If the pre-activation bounds are too loose, just use IBP. + # torch.ones_like(index).to(lower) * (-100.) + ).view(lower.shape) + else: + # Lookup the lower bound slope from the pre-computed table. + d_lower = torch.index_select( + self.d_lower, 0, index).view(lower.shape) + + # Indices of neurons with lower bound <=0, whose optimal slope to upper + # bound the function was pre-computed. + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + (lower / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + if index.max() >= self.d_upper.numel(): + warnings.warn(f'Pre-activation bounds are too loose for {self}') + # Lookup the lower bound slope from the pre-computed table. + d_upper = torch.where( + (index < self.d_upper.numel()).view(upper.shape), + torch.index_select( + self.d_upper, 0, index.clamp(max=self.d_upper.numel() - 1) + ).view(upper.shape), + upper, + ) + else: + d_upper = torch.index_select( + self.d_upper, 0, index).view(upper.shape) + + if self.opt_stage in ['opt', 'reuse']: + if not hasattr(self, 'alpha'): + # Raise an error if alpha is not created. + self._no_bound_parameters() + ns = self._start + + # Clipping is done here rather than after `opt.step()` call + # because it depends on pre-activation bounds + self.alpha[ns].data[0:2] = torch.max( + torch.min(self.alpha[ns][0:2], upper), lower) + self.alpha[ns].data[2:4] = torch.max( + torch.min(self.alpha[ns][2:4], upper), lower) + self.alpha[ns].data[4:6] = torch.min( + self.alpha[ns][4:6], d_lower) + self.alpha[ns].data[6:8] = torch.max( + self.alpha[ns][6:8], d_upper) + + # shape [2, out_c, n, c, h, w]. + tp_pos = self.alpha[ns][0:2] # For upper bound relaxation + tp_neg = self.alpha[ns][2:4] # For lower bound relaxation + tp_both_lower = self.alpha[ns][4:6] + tp_both_upper = self.alpha[ns][6:8] + + # No need to use tangent line, when the tangent point is at the left + # side of the preactivation lower bound. Simply connect the two sides. + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower)) + self.add_linear_relaxation( + mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), type='lower', + k=dfunc(tp_both_lower), x0=tp_both_lower) + + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper)) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), type='upper', + k=dfunc(tp_both_upper), x0=tp_both_upper) + + self.add_linear_relaxation( + mask=self.mask_neg, type='lower', k=dfunc(tp_neg), x0=tp_neg) + self.add_linear_relaxation( + mask=self.mask_pos, type='upper', k=dfunc(tp_pos), x0=tp_pos) + else: + if self.opt_stage == 'init': + # Initialize optimizable slope. + tp_both_lower_init = d_lower.detach() + tp_both_upper_init = d_upper.detach() + + if self.loose_threshold is not None: + # We will modify d_lower and d_upper inplace. + # So make a copy for these two. + tp_both_lower_init = tp_both_lower_init.clone() + tp_both_upper_init = tp_both_upper_init.clone() + # A different initialization if the pre-activation bounds + # are too loose + loose = torch.logical_or(lower < -self.loose_threshold, + upper > self.loose_threshold) + d_lower[loose] = lower[loose] + d_upper[loose] = upper[loose] + # tp_both_lower_init[loose] = lower[loose] + # tp_both_upper_init[loose] = upper[loose] + + ns = self._start + self.tp_both_lower_init[ns] = tp_both_lower_init + self.tp_both_upper_init[ns] = tp_both_upper_init + + # Not optimized (vanilla CROWN bound). + # Use the middle point slope as the lower/upper bound. Not optimized. + m = (lower + upper) / 2 + y_m = func(m) + k = dfunc(m) + # Lower bound is the middle point slope for the case input upper bound <= 0. + # Note that the upper bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation(mask=self.mask_neg, type='lower', k=k, x0=m, y0=y_m) + # Upper bound is the middle point slope for the case input lower bound >= 0. + # Note that the lower bound in this case is the direct line between (lower, func(lower)) and (upper, func(upper)). + self.add_linear_relaxation(mask=self.mask_pos, type='upper', k=k, x0=m, y0=y_m) + + # Now handle the case where input lower bound <=0 and upper bound >= 0. + # A tangent line starting at d_lower is guaranteed to be a lower bound given the input upper bound. + k = dfunc(d_lower) + # Another possibility is to use the direct line as the lower bound, when this direct line does not intersect with f. + # This is only valid when the slope at the input lower bound has a slope greater than the direct line. + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(lower)) + self.add_linear_relaxation(mask=mask_direct, type='lower', k=k_direct, x0=lower, y0=y_l) + # Otherwise we do not use the direct line, we use the d_lower slope. + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), + type='lower', k=k, x0=d_lower) + + # Do the same for the upper bound side when input lower bound <=0 and upper bound >= 0. + k = dfunc(d_upper) + mask_direct = torch.logical_and(self.mask_both, k_direct < dfunc(upper)) + self.add_linear_relaxation( + mask=mask_direct, type='upper', k=k_direct, x0=lower, y0=y_l) + self.add_linear_relaxation( + mask=torch.logical_xor(self.mask_both, mask_direct), + type='upper', k=k, x0=d_upper) + + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + self.bound_relax_impl( + x, self.activation_forward, self.activation_backward) + + def get_split_mask(self, lower, upper, input_index): + assert input_index == 0 + return torch.logical_and( + upper - lower >= self.split_min_gap, + torch.logical_or(upper >= self.split_range[0], + lower <= self.split_range[1]) + ) + + +class BoundSigmoid(BoundTanh): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options, + activation=('sigmoid', torch.sigmoid, dsigmoid)) + + +class BoundAtan(BoundTanh): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options, + activation=('arctan', torch.arctan, darctan)) + self.split_range = (-torch.inf, torch.inf) + + +class BoundTan(BoundAtan): + """ + The implementation of BoundTan is based on the S-shaped BoundAtan. We use the bounds from its + inverse function and directly convert the bounds of the inverse function to bounds of the original + function. This trick allows us to quickly implement bounds on inverse functions. + """ + + def forward(self, x): + return torch.tan(x) + + def _check_bounds(self, lower, upper): + # Lower and upper bounds must be within the same [-½π, ½π] region. + lower_periods = torch.floor((lower + 0.5 * torch.pi) / torch.pi) + upper_periods = torch.floor((upper + 0.5 * torch.pi) / torch.pi) + if not torch.allclose(lower_periods, upper_periods): + print('Tan preactivation lower bounds:\n', lower) + print('Tan preactivation upper bounds:\n', upper) + raise ValueError("BoundTan received pre-activation bounds that produce infinity. " + "The preactivation bounds are too loose. Try to reduce perturbation region.") + # Return the period number for each neuron. + # Period is 0 => bounds are within [-½π, ½π], + # Period is 1 => bounds are within [-½π + π, ½π + π] + # Period is -1 => bounds are within [-½π - π, ½π - π] + return lower_periods + + def _init_masks(self, x): + # The masks now must consider the periodicity. + lower = torch.remainder(x.lower + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi + upper = torch.remainder(x.upper + 0.5 * torch.pi, torch.pi) - 0.5 * torch.pi + self.mask_pos = lower >= 0 + self.mask_neg = upper <= 0 + self.mask_both = torch.logical_not(torch.logical_or(self.mask_pos, self.mask_neg)) + + def interval_propagate(self, *v): + # We need to check if the input lower and upper bounds are within the same period. + # Otherwise the bounds become infinity. + concrete_lower, concrete_upper = v[0][0], v[0][1] + self._check_bounds(concrete_lower, concrete_upper) + return super().interval_propagate(*v) + + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + periods = self._check_bounds(x.lower, x.upper) + periods = torch.pi * periods + # Create a fake x with inversed lower and upper. + inverse_x = lambda: None + inverse_x.lower = torch.tan(x.lower) + inverse_x.upper = torch.tan(x.upper) + super().bound_relax(inverse_x, init=init) + # Lower slope, lower bias, upper slope and upper bias are saved to + # self.lw, self.lb, self.uw, self.ub. We need to reverse them. + # E.g., y = self.lw * x + self.lb, now becomes x = 1./self.lw * y - self.lb / self.lw + # Additionally, we need to add the missing ½π periods. + new_upper_slope = 1. / self.lw + new_upper_bias = - self.lb / self.lw - periods / self.lw + new_lower_slope = 1. / self.uw + new_lower_bias = - self.ub / self.uw - periods / self.uw + self.lw = new_lower_slope + self.lb = new_lower_bias + self.uw = new_upper_slope + self.ub = new_upper_bias + diff --git a/auto_LiRPA/operators/trigonometric.py b/auto_LiRPA/operators/trigonometric.py new file mode 100644 index 0000000..451e3f6 --- /dev/null +++ b/auto_LiRPA/operators/trigonometric.py @@ -0,0 +1,438 @@ +import torch + +from .activation_base import BoundActivation +from .nonlinear import BoundOptimizableNonLinear + + +class BoundSin(BoundOptimizableNonLinear): + # Lookup tables shared by all BoundSin classes. + xl_lower_tb = None + xl_upper_tb = None + xu_lower_tb = None + xu_upper_tb = None + func, d_func = torch.sin, torch.cos + n_table_entries = 1001 + + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.ibp_intermediate = True + self.use_precompute = True + self.act_func = torch.sin + self.d_act_func = torch.cos + + # Bound limits used by IBP. + self.ibp_max_point = torch.pi / 2 + self.ibp_min_point = torch.pi * 3 / 2 + + self.all_table_x = torch.linspace( + 0, 2 * torch.pi, BoundSin.n_table_entries, device=self.device) + if self.use_precompute: + self.precompute_relaxation(self.act_func, self.d_act_func, x_limit = torch.pi / 2) + if BoundSin.xl_lower_tb is None: + # Generate look-up tables. + BoundSin.xl_lower_tb = BoundSin.get_lower_left_bound(self.all_table_x) + BoundSin.xl_upper_tb = BoundSin.get_upper_left_bound(self.all_table_x) + BoundSin.xu_lower_tb = BoundSin.get_lower_right_bound(self.all_table_x) + BoundSin.xu_upper_tb = BoundSin.get_upper_right_bound(self.all_table_x) + + def d2_act_func(self, x): + return -torch.sin(x) + + def _init_opt_parameters_impl(self, size_spec, name_start): + """Implementation of init_opt_parameters for each start_node.""" + l, u = self.inputs[0].lower, self.inputs[0].upper + shape = [size_spec] + list(l.shape) + alpha = torch.empty(12, *shape, device=l.device) + alpha.data[:4] = ((l + u) / 2).unsqueeze(0).expand(4, *shape) + alpha.data[4:6] = self.tp_both_lower_init[name_start].expand(2, *shape) + alpha.data[6:8] = self.tp_both_upper_init[name_start].expand(2, *shape) + alpha.data[8:10] = self.tp_left_lower_init[name_start].expand(2, *shape) + alpha.data[10:12] = self.tp_left_upper_init[name_start].expand(2, *shape) + return alpha + + def opt_init(self): + super().opt_init() + self.tp_both_lower_init = {} + self.tp_both_upper_init = {} + self.tp_left_lower_init = {} + self.tp_left_upper_init = {} + self.tp_right_lower_init = {} + self.tp_right_upper_init = {} + + def generate_inflections(self, lb, ub): + return + + def branch_input_domain(self, lb, ub): + lb_cycles = torch.floor((lb + 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi) + lb_clamped = lb - lb_cycles + ub_cycles = torch.floor((ub + 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi) + ub_clamped = ub - ub_cycles + + self.sigmoid_like_mask = (ub - lb <= torch.pi) + self.sigmoid_like_mask = torch.logical_and(self.sigmoid_like_mask, torch.logical_or( + torch.logical_and(lb_clamped <= 0.5 * torch.pi, ub_clamped <= 0.5 * torch.pi), + torch.logical_and(lb_clamped >= 0.5 * torch.pi, ub_clamped >= 0.5 * torch.pi))) + self.branch_mask = torch.logical_not(self.sigmoid_like_mask) + + self.mask_neg = torch.logical_and((self.d2_act_func(lb) >= 0), + torch.logical_and((self.d2_act_func(ub) >= 0), + self.sigmoid_like_mask)) + self.mask_pos = torch.logical_and((self.d2_act_func(lb) < 0), + torch.logical_and((self.d2_act_func(ub) < 0), + self.sigmoid_like_mask)) + self.mask_both = torch.logical_xor(self.sigmoid_like_mask, + torch.logical_or(self.mask_neg, self.mask_pos)) + + self.convex_concave = torch.logical_and(self.mask_both, + (self.d2_act_func(lb) >= 0)) + self.concave_convex = torch.logical_xor(self.mask_both, self.convex_concave) + + def generate_d_lower_upper(self, lower, upper): + # Indices of neurons with input upper bound >=0, whose optimal slope to lower bound the function was pre-computed. + # Note that for neurons with also input lower bound >=0, they will be masked later. + k_tensor = torch.floor(upper / (2 * torch.pi)) + upper_clamped = upper - k_tensor * (2 * torch.pi) + case1_mask = torch.logical_and(upper_clamped >= 0, upper_clamped <= torch.pi / 2) + upper_clamped_new = upper_clamped.clamp(min=0, max=(torch.pi / 2)) + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), + (upper_clamped_new / self.step_pre).to(torch.long).reshape(-1) + ) + 1 + # Lookup the lower bound slope from the pre-computed table. + d_lower = (torch.index_select(self.d_lower, 0, index).view(lower.shape) + k_tensor * 2 * torch.pi) * case1_mask + + case2_mask = torch.logical_and(upper_clamped >= torch.pi, upper_clamped <= 3 * torch.pi / 2) + upper_clamped_new = upper_clamped.clamp(min=torch.pi, max=(3 * torch.pi / 2)) + index = torch.max( + torch.zeros(upper.numel(), dtype=torch.long, device=upper.device), + ((torch.pi - upper_clamped_new) / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + # Lookup the lower bound slope from the pre-computed table. + d_upper = (torch.pi - torch.index_select(self.d_upper, 0, index).view(lower.shape) + k_tensor * 2 * torch.pi) * case2_mask + + # Indices of neurons with lower bound <=0, whose optimal slope to upper bound the function was pre-computed. + k_tensor = torch.floor(lower / (2 * torch.pi)) + lower_clamped = lower - k_tensor * (2 * torch.pi) + case3_mask = torch.logical_and(lower_clamped >= 3 * torch.pi / 2, lower_clamped <= 2 * torch.pi) + lower_clamped_new = lower_clamped.clamp(min=(3 * torch.pi / 2), max=2 * torch.pi) + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + ((lower_clamped_new - 2 * torch.pi) / -self.step_pre).to(torch.long).reshape(-1) + ) + 1 + d_upper += (torch.index_select(self.d_upper, 0, index).view(upper.shape) + (k_tensor + 1) * 2 * torch.pi) * case3_mask + + case4_mask = torch.logical_and(lower_clamped >= torch.pi / 2, lower_clamped <= torch.pi) + lower_clamped_new = lower_clamped.clamp(min=(torch.pi / 2), max=3 * torch.pi) + index = torch.max( + torch.zeros(lower.numel(), dtype=torch.long, device=lower.device), + ((torch.pi - lower_clamped_new) / self.step_pre).to(torch.long).reshape(-1) + ) + 1 + d_lower += (torch.pi - torch.index_select(self.d_lower, 0, index).view(upper.shape) + k_tensor * 2 * torch.pi) * case4_mask + return d_lower, d_upper + + @staticmethod + def n_crossing(start, end, s): + """Check how many times we will encounter value s + k*2*pi within start and end for any integer k.""" + cycles = torch.floor((end - start) / (2 * torch.pi)) # Number of 2pi cycles. + # Move s and end to the same 2 * pi cycle as start. + dist = torch.floor((s - start) / (2 * torch.pi)) + real_s = s - dist * 2 * torch.pi + real_end = end - cycles * 2 * torch.pi + return (real_s >= start).to(s) * (real_s <= real_end).to(s) + cycles + + @staticmethod + def arcsin(c): + """Arcsin with gradient fixes. + + arcsin(-1) and arcsin(1) have pathological gradients and should be avoided. + """ + if c.min() > -1 and c.max() < 1: + return torch.arcsin(c) + c_ = c.clone() + mask_neg = c == -1 + mask_pos = c == 1 + c_[mask_neg] = 0 + c_[mask_pos] = 0 + ret = torch.arcsin(c_) + ret[mask_neg] = -torch.pi / 2 + ret[mask_pos] = torch.pi / 2 + return ret + + @staticmethod + def get_intersection(start, end, c, theta=0.): + """Get the number of intersections between y = sin(x + theta) and y = c between start and end.""" + # Use arcsine to find the first 2 intersections. + crossing1 = BoundSin.arcsin(c) - theta + crossing2 = torch.pi - crossing1 - 2 * theta # Problematic at exact 1/2 pi, but ok in our case (happens only when lb=ub). + return BoundSin.n_crossing(start, end, crossing1) + BoundSin.n_crossing(start, end, crossing2) + + @staticmethod + def check_bound(tangent_point, x): + """Check whether the tangent line at tangent_point is a valid lower/upper bound for x.""" + # evaluate the value of the tangent line at x and see it is >= 0 or <=0. + d = BoundSin.d_func(tangent_point) + val = d * (x - tangent_point) + BoundSin.func(tangent_point) + # We want a positive margin when finding a lower line, but as close to 0 as possible. + # We want a negative margin when finding a upper line, but as close to 0 as possible. + margin = BoundSin.func(x) - val + return margin + + @staticmethod + @torch.no_grad() + def get_lower_left_bound(xl, steps=20): + """Get a global lower bound given lower bound on x. Return slope and intercept.""" + dtype = xl.dtype + # Constrain xl into the -0.5 pi to 1.5 pi region. + cycles = torch.floor((xl + 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi) + xl = xl - cycles + use_tangent_line = (xl >= torch.pi).to(dtype) + # Case 1: xl > pi, Lower tangent line is the only possible lower bound. + # Case 2: Binary search needed. Testing from another tangent endpoint in [pi, 1.5*pi]. It must be in this region. + left = torch.pi * torch.ones_like(xl) + # The right end guarantees the margin > 0 because it is basically a IBP lower bound (-1). + right = (1.5 * torch.pi) * torch.ones_like(xl) + last_right = right.clone() + for _ in range(steps): + mid = (left + right) / 2. + margin = BoundSin.check_bound(mid, xl) + pos_mask = (margin > 0).to(dtype) # We want to margin > 0 but at small as possible. + neg_mask = 1.0 - pos_mask + right = mid * pos_mask + right * neg_mask # We have positive margin, reduce right hand side. + last_right = mid * pos_mask + last_right * neg_mask # Always sound, since the margin is positive. + left = mid * neg_mask + left * pos_mask + d = xl * use_tangent_line + last_right * (1. - use_tangent_line) + # Return slope and bias. + return [d, cycles] + + @staticmethod + @torch.no_grad() + def get_upper_left_bound(xl, steps=20): + """Get a global upper bound given lower bound on x. Return slope and intercept.""" + dtype = xl.dtype + # Constrain xl into the -0.5 pi to 1.5 pi region. + cycles = torch.floor((xl - 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi) + xl = xl - cycles + use_tangent_line = (xl >= 2.0 * torch.pi).to(dtype) + # Case 1: xl > pi, Lower tangent line is the only possible lower bound. + # Case 2: Binary search needed. Testing from another tangent endpoint in [pi, 1.5*pi]. It must be in this region. + left = (2.0 * torch.pi) * torch.ones_like(xl) + # The right end guarantees the margin > 0 because it is basically a IBP lower bound (-1). + right = (2.5 * torch.pi) * torch.ones_like(xl) + last_right = right.clone() + for _ in range(steps): + mid = (left + right) / 2. + margin = BoundSin.check_bound(mid, xl) + pos_mask = (margin > 0).to(dtype) # We want to margin < 0 but at small as possible. + neg_mask = 1.0 - pos_mask + right = mid * neg_mask + right * pos_mask # We have positive margin, reduce right hand side. + last_right = mid * neg_mask + last_right * pos_mask # Always sound, since the margin is positive. + left = mid * pos_mask + left * neg_mask + d = xl * use_tangent_line + last_right * (1. - use_tangent_line) + # Return slope and bias. + return [d, cycles] + + @staticmethod + @torch.no_grad() + def get_lower_right_bound(xu, steps=20): + """Get a global lower bound given upper bound on x. Return slope and intercept.""" + # Constrain xu into the -0.5 pi to 1.5 pi region. + cycles = torch.floor((xu + 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi) + xu = xu - cycles + d, _ = BoundSin.get_lower_left_bound(torch.pi - xu, steps) + return [3 * torch.pi - d, cycles - 2 * torch.pi] + + @staticmethod + @torch.no_grad() + def get_upper_right_bound(xu, steps=20): + """Get a global upper bound given upper bound on x. Return slope and intercept.""" + # Constrain xu into the 0.5 pi to 2.5 pi region. + cycles = torch.floor((xu - 0.5 * torch.pi) / (2 * torch.pi)) * (2 * torch.pi) + xu = xu - cycles + d, _ = BoundSin.get_upper_left_bound(3 * torch.pi - xu, steps) + return [5 * torch.pi - d, cycles - 2 * torch.pi] + + def get_bound_tb(self, lb, ub): + """Find lower or upper bounds from lookup table.""" + step = 2 * torch.pi / (BoundSin.n_table_entries - 1) + # Move to 0 to 2 pi region. + lb_cycles = torch.floor(lb / (2 * torch.pi)) * (2 * torch.pi) + lb = torch.clamp(lb - lb_cycles, min=0, max=2 * torch.pi) + ub_cycles = torch.floor(ub / (2 * torch.pi)) * (2 * torch.pi) + ub = torch.clamp(ub - ub_cycles, min=0, max=2 * torch.pi) + # Find the indice within the lookup table from 0 - 2pi. + indices_lb = lb.div(step).long() + indices_ub = ub.div(step).long() + tangent_left_lower = BoundSin.xl_lower_tb[0][indices_lb] + tangent_left_upper = BoundSin.xl_upper_tb[0][indices_lb] + tangent_right_lower = BoundSin.xu_lower_tb[0][indices_ub] + tangent_right_upper = BoundSin.xu_upper_tb[0][indices_ub] + if self.opt_stage in ['opt', 'reuse']: + if not hasattr(self, 'alpha'): + # Raise an error if alpha is not created. + self._no_bound_parameters() + ns = self._start + + self.alpha[ns].data[8:10, :] = torch.min( + torch.max(self.alpha[ns][8:10, :], tangent_left_lower), tangent_right_lower) + self.alpha[ns].data[10:12, :] = torch.min( + torch.max(self.alpha[ns][10:12, :], tangent_left_upper), tangent_right_upper) + tangent_lower = self.alpha[ns][8:10, :] + tangent_upper = self.alpha[ns][10:12, :] + else: + tangent_lower = (tangent_left_lower + tangent_right_lower) / 2 + tangent_upper = (tangent_left_upper + tangent_right_upper) / 2 + if self.opt_stage == 'init': + ns = self._start + self.tp_left_lower_init[ns] = tangent_left_lower.detach() + self.tp_left_upper_init[ns] = tangent_left_upper.detach() + self.tp_right_lower_init[ns] = tangent_right_lower.detach() + self.tp_right_upper_init[ns] = tangent_right_upper.detach() + + d_lower = BoundSin.d_func(tangent_lower) + b_lower = BoundSin.func(tangent_lower) - d_lower * (tangent_lower + + torch.where(tangent_lower <= 1.5*torch.pi, + BoundSin.xl_lower_tb[1][indices_lb] + lb_cycles, + BoundSin.xu_lower_tb[1][indices_ub] + ub_cycles)) + d_upper = BoundSin.d_func(tangent_upper) + b_upper = BoundSin.func(tangent_upper) - d_upper * (tangent_upper + + torch.where(tangent_upper <= 2.5*torch.pi, + BoundSin.xl_upper_tb[1][indices_lb] + lb_cycles, + BoundSin.xu_upper_tb[1][indices_ub] + ub_cycles)) + return d_lower, b_lower, d_upper, b_upper + + def forward(self, x): + return torch.sin(x) + + def interval_propagate(self, *v): + # Check if a point is in [l, u], considering the 2pi period + def check_crossing(ll, uu, point): + return ((((uu - point) / (2 * torch.pi)).floor() + - ((ll - point) / (2 * torch.pi)).floor()) > 0).to(h_Ls.dtype) + h_L, h_U = v[0][0], v[0][1] + h_Ls, h_Us = self.forward(h_L), self.forward(h_U) + # If crossing pi/2, then max is fixed 1.0 + max_mask = check_crossing(h_L, h_U, self.ibp_max_point) + # If crossing pi*3/2, then min is fixed -1.0 + min_mask = check_crossing(h_L, h_U, self.ibp_min_point) + ub = torch.max(h_Ls, h_Us) + ub = max_mask + (1 - max_mask) * ub + lb = torch.min(h_Ls, h_Us) + lb = - min_mask + (1 - min_mask) * lb + return lb, ub + + def bound_relax_impl(self, lb, ub): + dtype = lb.dtype + + ub = torch.max(ub, lb + 1e-8) + + # Case 1: Connect the two points as a line + sub = self.func(ub) + slb = self.func(lb) + mid = (sub + slb) / 2. + smid = self.func((ub + lb) / 2) + gap = smid - mid + + min_preact = 1e-3 + mask_close = (ub - lb) < min_preact + case1_line_slope = torch.where(mask_close, self.d_act_func(ub), + (sub - slb) / (ub - lb).clamp(min=1e-10)) + case1_line_bias = slb - case1_line_slope * lb + # Check if there are crossings between the line and the sin function. + grad_crossings = self.get_intersection(lb, ub, case1_line_slope, theta=0.5 * torch.pi) + # If there is no crossing, then we can connect the two points together as a lower/upper bound. + use_line = grad_crossings == 1 + # Connected line is the upper bound. + upper_use_line = torch.logical_and(gap < 0, use_line) + # Connected line is the lower bound. + lower_use_line = torch.logical_and(gap >= 0, use_line) + + # Case 2: we will try the global lower/upper bounds at lb and ub. + # For the points and lb and ub, we can construct both lower and upper bounds. + (case_2_lower_slope, case_2_lower_bias, + case_2_upper_slope, case_2_upper_bias) = self.get_bound_tb(lb, ub) + + # Finally, choose between case 1 and case 2. + lower_use_line = lower_use_line.to(dtype) + not_lower_use_line = 1. - lower_use_line + upper_use_line = upper_use_line.to(dtype) + not_upper_use_line = 1. - upper_use_line + lower_slope = lower_use_line * case1_line_slope + not_lower_use_line * case_2_lower_slope + lower_bias = lower_use_line * case1_line_bias + not_lower_use_line * case_2_lower_bias + upper_slope = upper_use_line * case1_line_slope + not_upper_use_line * case_2_upper_slope + upper_bias = upper_use_line * case1_line_bias + not_upper_use_line * case_2_upper_bias + return lower_slope, lower_bias, upper_slope, upper_bias + + +class BoundCos(BoundSin): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + + self.ibp_max_point = 0.0 + self.ibp_min_point = torch.pi + + def forward(self, x): + return torch.cos(x) + + def bound_relax(self, x, init=False, dim_opt=None): + if init: + self.init_linear_relaxation(x, dim_opt) + # Shift the input by half_pi, and shifting the linear bounds back. + half_pi = 0.5 * torch.pi + lb = x.lower + half_pi + ub = x.upper + half_pi + self.generate_inflections(lb, ub) + self.branch_input_domain(lb, ub) + self.bound_relax_impl_sigmoid(lb, ub, self.act_func, self.d_act_func) + if self.opt_stage is None and self.sigmoid_like_mask.all(): + self.lb = self.lw * half_pi + self.lb + self.ub = self.uw * half_pi + self.ub + return + lower_slope, lower_bias, upper_slope, upper_bias = self.bound_relax_impl(lb, ub) + self.lw = self.lw * self.sigmoid_like_mask + self.branch_mask * lower_slope + self.lb = (self.sigmoid_like_mask * (self.lw * half_pi + self.lb) + + self.branch_mask * (lower_slope * half_pi + lower_bias)) + self.uw = self.uw * self.sigmoid_like_mask + self.branch_mask * upper_slope + self.ub = (self.sigmoid_like_mask * (self.uw * half_pi + self.ub) + + self.branch_mask * (upper_slope * half_pi + upper_bias)) + + +class BoundSec(BoundActivation): + def __init__(self, attr=None, inputs=None, output_index=0, options=None): + super().__init__(attr, inputs, output_index, options) + self.ibp_intermediate = True + + def forward(self, x): + return 1. / torch.cos(x) + + def bound_relax(self, x, init=False): + assert x.lower.min() > -torch.pi / 2 + assert x.upper.max() < torch.pi / 2 + + x_L = x.lower + x_U = x.upper + y_L = self.forward(x_L) + y_U = self.forward(x_U) + mask_close = x_U - x_L < 1e-8 + upper_k = torch.where( + mask_close, + y_L * torch.tan(x_L), + (y_U - y_L) / (x_U - x_L).clamp(min=1e-8) + ) + self.uw = upper_k + self.ub = -upper_k * x_L + y_L + + mid = (x_L + x_U) / 2 + y_mid = self.forward(mid) + lower_k = y_mid * torch.tan(mid) + self.lw = lower_k + self.lb = -lower_k * mid + y_mid + + def interval_propagate(self, *v): + h_L, h_U = v[0][0], v[0][1] + assert h_L.min() > -torch.pi / 2 + assert h_U.max() < torch.pi / 2 + y_L = self.forward(h_L) + y_U = self.forward(h_U) + lower = (h_U < 0) * (y_U - 1) + (h_L > 0) * (y_L - 1) + 1 + upper = torch.max(y_L, y_U) + return lower, upper diff --git a/auto_LiRPA/opt_pruner.py b/auto_LiRPA/opt_pruner.py new file mode 100644 index 0000000..256eeae --- /dev/null +++ b/auto_LiRPA/opt_pruner.py @@ -0,0 +1,312 @@ +"""Pruning during the optimization.""" + +import time + +import torch + + +class OptPruner: + + def __init__(self, x, threshold, multi_spec_keep_func, loss_reduction_func, + decision_thresh, fix_interm_bounds, + epsilon_over_decision_thresh): + self.x = x + self.threshold = threshold + self.multi_spec_keep_func = multi_spec_keep_func + self.loss_reduction_func = loss_reduction_func + self.decision_thresh = decision_thresh + self.fix_interm_bounds = fix_interm_bounds + self.epsilon_over_decision_thresh = epsilon_over_decision_thresh + + # For computing the positive domain ratio + self.original_size = x[0].shape[0] + self.pruning_in_iteration = False + self.preserve_mask = None + self.preserve_mask_next = None + self.time = 0 + + # For holding full-sized alphas + self.cached_alphas = {} + + def prune(self, x, C, ret_l, ret_u, ret, full_l, full_ret_l, full_ret_u, + full_ret, interm_bounds, aux_reference_bounds, + stop_criterion_func, bound_lower): + # positive domains may already be filtered out, so we use all domains - + # negative domains to compute + # FIXME Only using ret_l but not ret_u. + if self.decision_thresh is not None and ret_l is not None: + if (isinstance(self.decision_thresh, torch.Tensor) + and self.decision_thresh.numel() > 1 + and self.preserve_mask is not None): + if self.decision_thresh.shape[-1] == 1: + # single spec with pruned domains + negative_domain = ( + ret_l.view(-1) + <= self.decision_thresh[self.preserve_mask].view(-1) + ).sum() + else: + # multiple spec with pruned domains + negative_domain = self.multi_spec_keep_func( + ret_l <= self.decision_thresh[self.preserve_mask]).sum() + else: + if ret_l.shape[-1] == 1: + # single spec + negative_domain = ( + ret_l.view(-1) <= self.decision_thresh.view(-1)).sum() + else: + # multiple spec + negative_domain = self.multi_spec_keep_func( + ret_l <= self.decision_thresh).sum() + positive_domain_num = self.original_size - negative_domain + else: + positive_domain_num = -1 + positive_domain_ratio = float( + positive_domain_num) / float(self.original_size) + # threshold is 10% by default + self.next_iter_pruning_in_iteration = ( + self.decision_thresh is not None + and positive_domain_ratio > self.threshold) + + if self.pruning_in_iteration: + stime = time.time() + self.get_preserve_mask(ret_l) + # prune C + if C is not None and C.shape[0] == x[0].shape[0]: + C = C[self.now_preserve_mask] # means C is also batch specific + # prune x + x, pre_prune_size = self._prune_x(x) + # prune bounds + ret_prune = self._prune_bounds_by_mask( + ret_l, ret_u, ret, + interm_bounds, aux_reference_bounds, pre_prune_size) + full_l, full_ret_l, full_ret_u, full_ret = ret_prune + self.time += time.time() - stime + + stop_criterion = stop_criterion_func( + full_ret_l) if bound_lower else stop_criterion_func(-full_ret_u) + if (type(stop_criterion) != bool and stop_criterion.numel() > 1 + and self.pruning_in_iteration): + stop_criterion = stop_criterion[self.preserve_mask] + + return (x, C, full_l, full_ret_l, full_ret_u, + full_ret, stop_criterion) + + def prune_idx(self, idx_mask, idx, x): + if self.pruning_in_iteration: + # local sparse index of preserved samples where + # idx == true + local_idx = idx_mask[self.preserve_mask].nonzero().view(-1) + # idx is global sparse index of preserved samples where + # idx == true + new_idx = torch.zeros_like( + idx_mask, dtype=torch.bool, device=x[0].device) + new_idx[self.preserve_mask] = idx_mask[self.preserve_mask] + idx = new_idx.nonzero().view(-1) + reference_idx = local_idx + else: + reference_idx = idx + return reference_idx, idx + + def next_iter(self): + if self.pruning_in_iteration: + self.preserve_mask = self.preserve_mask_next + if (not self.pruning_in_iteration + and self.next_iter_pruning_in_iteration): + # init preserve_mask etc + self.preserve_mask = torch.arange( + 0, self.x[0].shape[0], device=self.x[0].device, dtype=torch.long) + self.pruning_in_iteration = True + + def update_best(self, full_ret_l, full_ret_u, best_ret): + if self.pruning_in_iteration: + # overwrite pruned cells in best_ret by threshold + eps + fin_l, fin_u = best_ret + if fin_l is not None: + new_fin_l = full_ret_l + new_fin_l[self.preserve_mask] = fin_l[self.preserve_mask] + fin_l = new_fin_l + if fin_u is not None: + new_fin_u = full_ret_u + new_fin_u[self.preserve_mask] = fin_u[self.preserve_mask] + fin_u = new_fin_u + best_ret = (fin_l, fin_u) + return best_ret + + def update_ratio(self, full_l, full_ret_l): + if self.decision_thresh is not None and full_l.numel() > 0: + stime = time.time() + with torch.no_grad(): + if isinstance(self.decision_thresh, torch.Tensor): + if self.decision_thresh.shape[-1] == 1: + neg_domain_num = torch.sum( + full_ret_l.view(-1) <= self.decision_thresh.view(-1) + ).item() + else: + neg_domain_num = torch.sum(self.multi_spec_keep_func( + full_ret_l <= self.decision_thresh)).item() + else: + if full_l.shape[-1] == 1: + neg_domain_num = torch.sum( + full_ret_l.view(-1) <= self.decision_thresh).item() + else: + neg_domain_num = torch.sum(self.multi_spec_keep_func( + full_ret_l <= self.decision_thresh)).item() + now_pruning_ratio = ( + 1.0 - float(neg_domain_num) / float(full_l.shape[0])) + print('pruning_in_iteration open status:', + self.pruning_in_iteration) + print('ratio of positive domain =', + full_l.shape[0] - neg_domain_num, + '/', full_l.numel(), '=', now_pruning_ratio) + self.time += time.time() - stime + print('pruning-in-iteration extra time:', self.time) + + @torch.no_grad() + def _prune_x(self, x): + """ + Prune x by given now_preserve_mask. + """ + x = list(x) + pre_prune_size = x[0].shape[0] + x[0].data = x[0][self.now_preserve_mask].data + if hasattr(x[0], 'ptb'): + if x[0].ptb.x_L is not None: + x[0].ptb.x_L = x[0].ptb.x_L[self.now_preserve_mask] + if x[0].ptb.x_U is not None: + x[0].ptb.x_U = x[0].ptb.x_U[self.now_preserve_mask] + x = tuple(x) + + return x, pre_prune_size + + @torch.no_grad() + def _prune_bounds_by_mask(self, ret_l, ret_u, ret, interm_bounds, + aux_reference_bounds, pre_prune_size): + """ + Prune bounds by given now_preserve_mask. + """ + full_ret_l, full_l = self._recover_bounds_to_full_batch(ret_l) + full_ret_u, full_u = self._recover_bounds_to_full_batch(ret_u) + + full_ret = (full_ret_l, full_ret_u) + ret[2:] + + if self.fix_interm_bounds: + interval_to_prune = interm_bounds + else: + interval_to_prune = None + if interval_to_prune is not None: + for k, v in interval_to_prune.items(): + interm_interval_l, interm_interval_r = v[0], v[1] + if interm_interval_l.shape[0] == pre_prune_size: + # the first dim is batch size and matches preserve mask + interm_interval_l = interm_interval_l[ + self.now_preserve_mask] + if interm_interval_r.shape[0] == pre_prune_size: + # the first dim is batch size and matches preserve mask + interm_interval_r = interm_interval_r[ + self.now_preserve_mask] + interval_to_prune[k] = [interm_interval_l, interm_interval_r] + + if aux_reference_bounds is not None: + for k in aux_reference_bounds: + aux_ref_l, aux_ref_r = aux_reference_bounds[k] + if aux_ref_l.shape[0] == pre_prune_size: + # the first dim is batch size and matches the preserve mask + aux_ref_l = aux_ref_l[self.now_preserve_mask] + if aux_ref_r.shape[0] == pre_prune_size: + # the first dim is batch size and matches the preserve mask + aux_ref_r = aux_ref_r[self.now_preserve_mask] + aux_reference_bounds[k] = [aux_ref_l, aux_ref_r] + + # update the global mask here for possible next iteration + self.preserve_mask_next = self.preserve_mask[self.now_preserve_mask] + + return full_l, full_ret_l, full_ret_u, full_ret + + @torch.no_grad() + def get_preserve_mask(self, ret_l): + """ + Get preserve mask by decision_thresh to filter out the satisfied bounds. + """ + if (isinstance(self.decision_thresh, torch.Tensor) + and self.decision_thresh.numel() > 1): + if self.decision_thresh.shape[-1] == 1: + self.now_preserve_mask = ( + ret_l <= self.decision_thresh[self.preserve_mask] + ).view(-1).nonzero().view(-1) + else: + self.now_preserve_mask = self.multi_spec_keep_func( + ret_l <= self.decision_thresh[self.preserve_mask] + ).nonzero().view(-1) + else: + if self.decision_thresh.shape[-1] == 1: + self.now_preserve_mask = ( + ret_l <= self.decision_thresh).view(-1).nonzero().view(-1) + else: + self.now_preserve_mask = self.multi_spec_keep_func( + ret_l <= self.decision_thresh).nonzero().view(-1) + + def _recover_bounds_to_full_batch(self, ret): + """ + Recover lower and upper bounds to full batch size so that later we can + directly update using the full batch size of l and u. + """ + if ret is not None: + if (isinstance(self.decision_thresh, torch.Tensor) + and self.decision_thresh.numel() > 1): + full_ret = ( + self.decision_thresh.clone().to(ret.device).type(ret.dtype) + + self.epsilon_over_decision_thresh) + else: + num_decision_thresh = self.decision_thresh + if isinstance(num_decision_thresh, torch.Tensor): + num_decision_thresh = num_decision_thresh.item() + full_ret = torch.full( + (self.original_size,) + tuple(ret.shape[1:]), + fill_value=(num_decision_thresh + + self.epsilon_over_decision_thresh), + device=ret.device, dtype=ret.dtype) + full_ret[self.preserve_mask] = ret + if full_ret.shape[1] > 1: + full_reduced_ret = self.loss_reduction_func(full_ret) + else: + full_reduced_ret = full_ret + else: + full_ret = full_reduced_ret = None + + return full_ret, full_reduced_ret + + def cache_full_sized_alpha(self, optimizable_activations: list): + """ + When preserve mask is in use, cache the full-sized alphas in self.cached_alphas, + and rewrite the alphas in nodes according to the preserve mask. + The full-sized alphas will be recovered back to nodes after compute_bounds, + via the function named recover_full_sized_alphas() + :param optimizable_activations: list of nodes that may have slope alphas as optimizable variables + :return: None + """ + if self.pruning_in_iteration: + for act in optimizable_activations: + if act.name in self.cached_alphas: + self.cached_alphas[act.name].clear() + self.cached_alphas[act.name] = {} + if act.alpha is not None: + for start_node in act.alpha: + # cached alphas and alphas stored in nodes should share the same memory space + self.cached_alphas[act.name][start_node] = act.alpha[start_node] + act.alpha[start_node] = act.alpha[start_node][:, :, self.preserve_mask] + + def recover_full_sized_alpha(self, optimizable_activations: list): + """ + After bound computation, recover the full-sized alphas back to nodes. + :param optimizable_activations: ist of nodes that may have slope alphas as optimizable variables + :return: None + """ + if self.pruning_in_iteration: + for act in optimizable_activations: + for start_node in self.cached_alphas[act.name]: + act.alpha[start_node] = self.cached_alphas[act.name][start_node] + + def clean_full_sized_alpha_cache(self): + for act_node in self.cached_alphas: + self.cached_alphas[act_node].clear() + self.cached_alphas.clear() diff --git a/auto_LiRPA/optimize_graph.py b/auto_LiRPA/optimize_graph.py new file mode 100644 index 0000000..0324fdc --- /dev/null +++ b/auto_LiRPA/optimize_graph.py @@ -0,0 +1,99 @@ +"""Optimize the graph to merge nodes and remove unnecessary ones. + +Initial and experimental code only. +""" + +from auto_LiRPA.bound_ops import (BoundActivation, BoundMul, BoundSqr, BoundDiv, + BoundPow, BoundReciprocal, BoundBuffers, + BoundCos, BoundSec) + +from auto_LiRPA.utils import logger + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule + + +def _optimize_graph(self: 'BoundedModule'): + """Optimize the graph to remove some unnecessary nodes.""" + merge_identical_act(self) + convert_sqr(self) + div_to_mul(self) + merge_sec(self) + + if self.bound_opts['optimize_graph']['optimizer'] is not None: + # Use the custom graph optimizer + self.bound_opts['optimize_graph']['optimizer'](self) + + for node in list(self.nodes()): + if (not node.output_name + and node.name != self.final_name + and node.name not in self.root_names): + self.delete_node(node) + + +def merge_sec(model: 'BoundedModule'): + nodes = list(model.nodes()) + for node in nodes: + if type(node) == BoundReciprocal and type(node.inputs[0]) == BoundCos: + node_new = BoundSec(inputs=[node.inputs[0].inputs[0]]) + node_new.name = f'{node.inputs[0].name}/sec' + model.add_nodes([node_new]) + model.replace_node(node, node_new) + + +def div_to_mul(model: 'BoundedModule'): + nodes = list(model.nodes()) + for node in nodes: + if type(node) == BoundDiv: + logger.debug('Replacing BoundDiv node: %s', node) + node_reciprocal = BoundReciprocal(inputs=[node.inputs[1]]) + node_reciprocal.name = f'{node.name}/reciprocal' + model.add_nodes([node_reciprocal]) + node_mul = BoundMul(inputs=[node.inputs[0], node_reciprocal], + options=model.bound_opts) + node_mul.name = f'{node.name}/mul' + model.add_nodes([node_mul]) + model.replace_node(node, node_mul) + + +def convert_sqr(model: 'BoundedModule'): + """Replace BoundMul or Bound Pow with BoundSqr if applicable. + + 1. If the two inputs nodes of a BoundMul node are the same, use BoundSqr. + 2. Pow(x, 2) can be replaced with BoundSqr. + """ + nodes = list(model.nodes()) + for node in nodes: + replace = False + if type(node) == BoundMul and node.inputs[0] == node.inputs[1]: + replace = True + elif type(node) == BoundPow: + if (isinstance(node.inputs[1], BoundBuffers) + and node.inputs[1].buffer == 2): + replace = True + if replace: + node_new = BoundSqr(inputs=[node.inputs[0]]) + node_new.name = f'{node.name}/sqr' + model.add_nodes([node_new]) + logger.debug('Replaceing %s with %s', node, node_new) + model.replace_node(node, node_new) + + +def merge_identical_act(model: 'BoundedModule'): + """Merge identical BoundActivation""" + nodes = list(model.nodes()) + merged = [False] * len(nodes) + for i in range(len(nodes)): + if (not merged[i] + and isinstance(nodes[i], BoundActivation) + and len(nodes[i].inputs) == 1): + for j in range(i + 1, len(nodes)): + if (not merged[j] + and type(nodes[j]) == type(nodes[i]) + and len(nodes[i].inputs) == 1): + if nodes[i].inputs[0] == nodes[j].inputs[0]: + logger.debug('Merging node %s to %s', nodes[j], nodes[i]) + model.replace_node(nodes[j], nodes[i]) + merged[j] = True + diff --git a/auto_LiRPA/optimized_bounds.py b/auto_LiRPA/optimized_bounds.py index 00e6104..b6ba75e 100644 --- a/auto_LiRPA/optimized_bounds.py +++ b/auto_LiRPA/optimized_bounds.py @@ -1,111 +1,122 @@ import time import os -import warnings from collections import OrderedDict from contextlib import ExitStack +from auto_LiRPA.operators.leaf import BoundInput import torch from torch import optim +from .beta_crown import print_optimized_beta from .cuda_utils import double2float -from .utils import logger +from .utils import logger, reduction_sum, multi_spec_keep_func_all +from .opt_pruner import OptPruner + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule + + +default_optimize_bound_args = { + 'enable_alpha_crown': True, # Enable optimization of alpha. + 'enable_beta_crown': False, # Enable beta split constraint. + + 'apply_output_constraints_to': None, # Enable optimization w.r.t. output constraints. + + 'iteration': 20, # Number of alpha/beta optimization iterations. + # Share some alpha variables to save memory at the cost of slightly + # looser bounds. + 'use_shared_alpha': False, + # Optimizer used for alpha and beta optimization. + 'optimizer': 'adam', + # Save best results of alpha/beta/bounds during optimization. + 'keep_best': True, + # Only optimize bounds of last layer during alpha/beta CROWN. + 'fix_interm_bounds': True, + # Learning rate for the optimizable parameter alpha in alpha-CROWN. + 'lr_alpha': 0.5, + # Learning rate for the optimizable parameter beta in beta-CROWN. + 'lr_beta': 0.05, + 'lr_cut_beta': 5e-3, # Learning rate for optimizing cut betas. + # Initial alpha variables by calling CROWN once. + 'init_alpha': True, + 'lr_coeffs': 0.01, # Learning rate for coeffs for refinement + # Layers to be refined, separated by commas. + # -1 means preactivation before last activation. + 'intermediate_refinement_layers': [-1], + # When batch size is not 1, this reduction function is applied to + # reduce the bounds into a scalar. + 'loss_reduction_func': reduction_sum, + # Criteria function of early stop. + 'stop_criterion_func': lambda x: False, + # Learning rate decay factor during bounds optimization. + 'lr_decay': 0.98, + # Number of iterations that we will start considering early stop + # if tracking no improvement. + 'early_stop_patience': 10, + # Start to save optimized best bounds + # when current_iteration > int(iteration*start_save_best) + 'start_save_best': 0.5, + # Use double fp (float64) at the last iteration in alpha/beta CROWN. + 'use_float64_in_last_iteration': False, + # Prune verified domain within iteration. + 'pruning_in_iteration': False, + # Percentage of the minimum domains that can apply pruning. + 'pruning_in_iteration_threshold': 0.2, + # For specification that will output multiple bounds for one + # property, we use this function to prune them. + 'multi_spec_keep_func': multi_spec_keep_func_all, + # Use the newly fixed loss function. By default, it is set to False + # for compatibility with existing use cases. + # Try to ensure that the parameters always match with the optimized bounds. + 'deterministic': False, +} + + +def opt_reuse(self: 'BoundedModule'): + for node in self.get_enabled_opt_act(): + node.opt_reuse() + + +def opt_no_reuse(self: 'BoundedModule'): + for node in self.get_enabled_opt_act(): + node.opt_no_reuse() def _set_alpha(optimizable_activations, parameters, alphas, lr): - """ - Set best_alphas, alphas and parameters list - """ + """Set best_alphas, alphas and parameters list.""" for node in optimizable_activations: alphas.extend(list(node.alpha.values())) node.opt_start() # Alpha has shape (2, output_shape, batch_dim, node_shape) parameters.append({'params': alphas, 'lr': lr, 'batch_dim': 2}) # best_alpha is a dictionary of dictionary. Each key is the alpha variable - # for one relu layer, and each value is a dictionary contains all relu - # layers after that layer as keys. + # for one activation layer, and each value is a dictionary contains all + # activation layers after that layer as keys. best_alphas = OrderedDict() for m in optimizable_activations: best_alphas[m.name] = {} for alpha_m in m.alpha: best_alphas[m.name][alpha_m] = m.alpha[alpha_m].detach().clone() - # We will directly replace the dictionary for each relu layer after + # We will directly replace the dictionary for each activation layer after # optimization, so the saved alpha might not have require_grad=True. m.alpha[alpha_m].requires_grad_() return best_alphas -def _set_beta( - self, relus, optimizable_activations, single_node_split, - enable_opt_interm_bounds, betas, opt_coeffs, parameters, - lr_coeffs, opt_bias, lr_beta, lr_cut_beta, cutter, dense_coeffs_mask): +def _set_gammas(nodes, parameters): """ - Set betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases - and best_biases. + Adds gammas to parameters list """ - coeffs = best_coeffs = biases = best_biases = None - if len(relus) != len(optimizable_activations): - warnings.warn( - 'Only relu split is supported so far, this model contains other ' - 'optimizable activations that may not apply split.') - - if single_node_split: - for node in relus: - if enable_opt_interm_bounds and node.sparse_beta is not None: - for key in node.sparse_beta.keys(): - if node.sparse_beta[key] is not None: - betas.append(node.sparse_beta[key]) - else: - if node.sparse_beta is not None: - betas.append(node.sparse_beta) - else: - betas = self.beta_params + self.single_beta_params - if opt_coeffs: - coeffs = [dense_coeffs['dense'] - for dense_coeffs in self.split_dense_coeffs_params - ] + self.coeffs_params - dense_coeffs_mask = [dense_coeffs['mask'] - for dense_coeffs in self.split_dense_coeffs_params] - parameters.append({'params': coeffs, 'lr': lr_coeffs}) - best_coeffs = [coeff.detach().clone() for coeff in coeffs] - if opt_bias: - biases = self.bias_params - parameters.append({'params': biases, 'lr': lr_coeffs}) - best_biases = [bias.detach().clone() for bias in biases] - - # Beta has shape (batch, max_splits_per_layer) - parameters.append({'params': betas, 'lr': lr_beta, 'batch_dim': 0}) - - if self.cut_used: - # also need to optimize cut betas - parameters.append({'params': self.cut_beta_params, - 'lr': lr_cut_beta, 'batch_dim': 0}) - betas = betas + self.cut_beta_params - - if enable_opt_interm_bounds and betas: - best_betas = OrderedDict() - for m in optimizable_activations: - best_betas[m.name] = {} - for beta_m, beta in m.sparse_beta.items(): - best_betas[m.name][beta_m] = beta.detach().clone() - if self.cut_used: - best_betas['cut'] = [] - for general_betas in self.cut_beta_params: - best_betas['cut'].append(general_betas.detach().clone()) - else: - best_betas = [b.detach().clone() for b in betas] - - if self.cut_used and getattr(cutter, 'opt', False): - parameters.append(cutter.get_parameters()) - - return ( - betas, best_betas, coeffs, dense_coeffs_mask, best_coeffs, biases, - best_biases) - + gammas = [] + for node in nodes: + if hasattr(node, 'gammas'): + gammas.append(node.gammas) + gamma_lr = 0.1 + parameters.append({'params': gammas, 'lr': gamma_lr}) def _save_ret_first_time(bounds, fill_value, x, best_ret): - """ - Save results at the first iteration to best_ret - """ + """Save results at the first iteration to best_ret.""" if bounds is not None: best_bounds = torch.full_like( bounds, fill_value=fill_value, device=x[0].device, dtype=x[0].dtype) @@ -120,138 +131,7 @@ def _save_ret_first_time(bounds, fill_value, x, best_ret): return best_bounds -@torch.no_grad() -def _get_preserve_mask( - decision_thresh, ret_l, preserve_mask, multi_spec_keep_func): - """ - Get preserve mask by decision_thresh to filter out the satisfied bounds. - """ - if (isinstance(decision_thresh, torch.Tensor) - and decision_thresh.numel() > 1): - if decision_thresh.shape[-1] == 1: - now_preserve_mask = ( - ret_l <= decision_thresh[preserve_mask] - ).view(-1).nonzero().view(-1) - else: - now_preserve_mask = multi_spec_keep_func( - ret_l <= decision_thresh[preserve_mask]).nonzero().view(-1) - else: - if decision_thresh.shape[-1] == 1: - now_preserve_mask = ( - ret_l <= decision_thresh).view(-1).nonzero().view(-1) - else: - now_preserve_mask = multi_spec_keep_func( - ret_l <= decision_thresh).nonzero().view(-1) - - return now_preserve_mask - - -def _recover_bounds_to_full_batch( - ret, decision_thresh, epsilon_over_decision_thresh, original_size, - preserve_mask, loss_reduction_func): - """ - Recover lower and upper bounds to full batch size so that later we can - directly update using the full batch size of l and u. - """ - if ret is not None: - if (isinstance(decision_thresh, torch.Tensor) - and decision_thresh.numel() > 1): - full_ret = (decision_thresh.clone().to(ret.device).type(ret.dtype) - + epsilon_over_decision_thresh) - else: - num_decision_thresh = decision_thresh - if isinstance(num_decision_thresh, torch.Tensor): - num_decision_thresh = num_decision_thresh.item() - full_ret = torch.full( - (original_size,) + tuple(ret.shape[1:]), - fill_value=num_decision_thresh + epsilon_over_decision_thresh, - device=ret.device, dtype=ret.dtype) - full_ret[preserve_mask] = ret - if full_ret.shape[1] > 1: - full_reduced_ret = loss_reduction_func(full_ret) - else: - full_reduced_ret = full_ret - else: - full_ret = full_reduced_ret = None - - return full_ret, full_reduced_ret - - -@torch.no_grad() -def _prune_bounds_by_mask( - now_preserve_mask, decision_thresh, ret_l, ret_u, ret, preserve_mask, - epsilon_over_decision_thresh, original_size, loss_reduction_func, - beta, intermediate_beta_enabled, - fix_intermediate_layer_bounds, intermediate_layer_bounds, - aux_reference_bounds, partial_intermediate_layer_bounds, - pre_prune_size): - """ - Prune bounds by given now_preserve_mask. - """ - full_ret_l, full_l = _recover_bounds_to_full_batch( - ret_l, decision_thresh, epsilon_over_decision_thresh, - original_size, preserve_mask, loss_reduction_func) - - full_ret_u, full_u = _recover_bounds_to_full_batch( - ret_u, decision_thresh, epsilon_over_decision_thresh, - original_size, preserve_mask, loss_reduction_func) - - full_ret = (full_ret_l, full_ret_u) + ret[2:] - - if beta and intermediate_beta_enabled: - # prune the partial_intermediate_layer_bounds - interval_to_prune = partial_intermediate_layer_bounds - elif fix_intermediate_layer_bounds: - interval_to_prune = intermediate_layer_bounds - else: - interval_to_prune = None - if interval_to_prune is not None: - for k, v in interval_to_prune.items(): - interm_interval_l, interm_interval_r = v[0], v[1] - if interm_interval_l.shape[0] == pre_prune_size: - # the first dim is batch size and matches preserve mask - interm_interval_l = interm_interval_l[now_preserve_mask] - if interm_interval_r.shape[0] == pre_prune_size: - # the first dim is batch size and matches preserve mask - interm_interval_r = interm_interval_r[now_preserve_mask] - interval_to_prune[k] = [interm_interval_l, interm_interval_r] - - if aux_reference_bounds is not None: - for k in aux_reference_bounds: - aux_ref_l, aux_ref_r = aux_reference_bounds[k] - if aux_ref_l.shape[0] == pre_prune_size: - # the first dim is batch size and matches the preserve mask - aux_ref_l = aux_ref_l[now_preserve_mask] - if aux_ref_r.shape[0] == pre_prune_size: - # the first dim is batch size and matches the preserve mask - aux_ref_r = aux_ref_r[now_preserve_mask] - aux_reference_bounds[k] = [aux_ref_l, aux_ref_r] - - # update the global mask here for possible next iteration - preserve_mask_next = preserve_mask[now_preserve_mask] - - return full_l, full_ret_l, full_u, full_ret_u, full_ret, preserve_mask_next - - -@torch.no_grad() -def _prune_x(x, now_preserve_mask): - """ - Prune x by given now_preserve_mask. - """ - x = list(x) - pre_prune_size = x[0].shape[0] - x[0].data = x[0][now_preserve_mask].data - if hasattr(x[0], 'ptb'): - if x[0].ptb.x_L is not None: - x[0].ptb.x_L = x[0].ptb.x_L[now_preserve_mask] - if x[0].ptb.x_U is not None: - x[0].ptb.x_U = x[0].ptb.x_U[now_preserve_mask] - x = tuple(x) - - return x, pre_prune_size - - -def _to_float64(self, C, x, aux_reference_bounds, intermediate_layer_bounds): +def _to_float64(self: 'BoundedModule', C, x, aux_reference_bounds, interm_bounds): """ Transfer variables to float64 only in the last iteration to help alleviate floating point error. @@ -262,14 +142,14 @@ def _to_float64(self, C, x, aux_reference_bounds, intermediate_layer_bounds): # best_intermediate_bounds is linked to aux_reference_bounds! # we only need call .to() for one of them self._to(aux_reference_bounds, torch.float64, inplace=True) - intermediate_layer_bounds = self._to( - intermediate_layer_bounds, torch.float64) + interm_bounds = self._to( + interm_bounds, torch.float64) - return C, x, intermediate_layer_bounds + return C, x, interm_bounds -def _to_default_dtype( - self, x, total_loss, full_ret, ret, best_intermediate_bounds, return_A): +def _to_default_dtype(self: 'BoundedModule', x, total_loss, full_ret, ret, + best_intermediate_bounds, return_A): """ Switch back to default precision from float64 typically to adapt to afterwards operations. @@ -294,15 +174,16 @@ def _to_default_dtype( return total_loss, x, full_ret -def _get_idx_mask(idx, full_ret_bound, best_ret_bound): +def _get_idx_mask(idx, full_ret_bound, best_ret_bound, loss_reduction_func): """Get index for improved elements.""" assert idx in [0, 1], ( '0 means updating lower bound, 1 means updating upper bound') if idx == 0: - idx_mask = (full_ret_bound > best_ret_bound).any(dim=1).view(-1) + idx_mask = (loss_reduction_func(full_ret_bound) + > loss_reduction_func(best_ret_bound)).view(-1) else: - idx_mask = (full_ret_bound < best_ret_bound).any(dim=1).view(-1) - + idx_mask = (loss_reduction_func(full_ret_bound) + < loss_reduction_func(best_ret_bound)).view(-1) improved_idx = None if idx_mask.any(): # we only pick up the results improved in a batch @@ -310,103 +191,96 @@ def _get_idx_mask(idx, full_ret_bound, best_ret_bound): return idx_mask, improved_idx -def _update_best_ret( - full_ret_bound, best_ret_bound, full_ret, best_ret, need_update, idx): +def _update_best_ret(full_ret_bound, best_ret_bound, full_ret, best_ret, + need_update, loss_reduction_func, idx, deterministic=False): """Update best_ret_bound and best_ret by comparing with new results.""" assert idx in [0, 1], ( '0 means updating lower bound, 1 means updating upper bound') - idx_mask, improved_idx = _get_idx_mask(idx, full_ret_bound, best_ret_bound) + idx_mask, improved_idx = _get_idx_mask( + idx, full_ret_bound, best_ret_bound, loss_reduction_func) if improved_idx is not None: need_update = True - if idx == 0: - best_ret_bound[improved_idx] = torch.maximum( + compare = torch.max if idx == 0 else torch.min + if not deterministic: + best_ret_bound[improved_idx] = compare( full_ret_bound[improved_idx], best_ret_bound[improved_idx]) - if full_ret[idx] is not None: - best_ret[idx][improved_idx] = torch.maximum( - full_ret[idx][improved_idx], best_ret[idx][improved_idx]) else: - best_ret_bound[improved_idx] = torch.minimum( - full_ret_bound[improved_idx], best_ret_bound[improved_idx]) - if full_ret[idx] is not None: - best_ret[idx][improved_idx] = torch.minimum( - full_ret[idx][improved_idx], best_ret[idx][improved_idx]) + best_ret_bound[improved_idx] = full_ret_bound[improved_idx] + if full_ret[idx] is not None: + if not deterministic: + best_ret[idx][improved_idx] = compare( + full_ret[idx][improved_idx], + best_ret[idx][improved_idx]) + else: + best_ret[idx][improved_idx] = full_ret[idx][improved_idx] - return best_ret_bound, best_ret, idx_mask, improved_idx, need_update + return best_ret_bound, best_ret, need_update, idx_mask, improved_idx def _update_optimizable_activations( - optimizable_activations, intermediate_layer_bounds, - fix_intermediate_layer_bounds, best_intermediate_bounds, - reference_idx, idx, alpha, best_alphas): + optimizable_activations, interm_bounds, + fix_interm_bounds, best_intermediate_bounds, + reference_idx, idx, alpha, best_alphas, deterministic): """ Update bounds and alpha of optimizable_activations. """ for node in optimizable_activations: # Update best intermediate layer bounds only when they are optimized. - # If they are already fixed in intermediate_layer_bounds, then do + # If they are already fixed in interm_bounds, then do # nothing. - if (intermediate_layer_bounds is None - or node.inputs[0].name not in intermediate_layer_bounds - or not fix_intermediate_layer_bounds): - best_intermediate_bounds[node.name][0][idx] = torch.max( - best_intermediate_bounds[node.name][0][idx], - node.inputs[0].lower[reference_idx]) - best_intermediate_bounds[node.name][1][idx] = torch.min( - best_intermediate_bounds[node.name][1][idx], - node.inputs[0].upper[reference_idx]) - + if (interm_bounds is None + or node.inputs[0].name not in interm_bounds + or not fix_interm_bounds): + if deterministic: + best_intermediate_bounds[node.name][0][idx] = node.inputs[0].lower[reference_idx] + best_intermediate_bounds[node.name][1][idx] = node.inputs[0].upper[reference_idx] + else: + best_intermediate_bounds[node.name][0][idx] = torch.max( + best_intermediate_bounds[node.name][0][idx], + node.inputs[0].lower[reference_idx]) + best_intermediate_bounds[node.name][1][idx] = torch.min( + best_intermediate_bounds[node.name][1][idx], + node.inputs[0].upper[reference_idx]) if alpha: - # Each alpha has shape (2, output_shape, batch, *shape) for ReLU. + # Each alpha has shape (2, output_shape, batch, *shape) for act. # For other activation function this can be different. for alpha_m in node.alpha: - if node.alpha_batch_dim == 2: - best_alphas[node.name][alpha_m][:, :, - idx] = node.alpha[alpha_m][:, :, idx] - elif node.alpha_batch_dim == 3: - best_alphas[node.name][alpha_m][:, :, :, - idx] = node.alpha[alpha_m][:, :, :, idx] - else: - raise ValueError( - f'alpha_batch_dim={node.alpha_batch_dim} must be set ' - 'to 2 or 3 in BoundOptimizableActivation') + best_alphas[node.name][alpha_m][:, :, + idx] = node.alpha[alpha_m][:, :, idx] -def _update_best_beta( - self, enable_opt_interm_bounds, betas, optimizable_activations, - best_betas, idx): +def update_best_beta(self: 'BoundedModule', enable_opt_interm_bounds, betas, + best_betas, idx): """ Update best beta by given idx. """ if enable_opt_interm_bounds and betas: - for node in optimizable_activations: - for key in node.sparse_beta.keys(): - best_betas[node.name][key] = ( - node.sparse_beta[key].detach().clone()) + for node in self.splittable_activations: + for node_input in node.inputs: + for key in node_input.sparse_betas.keys(): + best_betas[node_input.name][key] = ( + node_input.sparse_betas[key].val.detach().clone()) if self.cut_used: for gbidx, general_betas in enumerate(self.cut_beta_params): + # FIXME need to check if 'cut' is a node name best_betas['cut'][gbidx] = general_betas.detach().clone() else: + for node in self.nodes_with_beta: + best_betas[node.name][idx] = node.sparse_betas[0].val[idx] if self.cut_used: regular_beta_length = len(betas) - len(self.cut_beta_params) - for beta_idx in range(regular_beta_length): - # regular beta crown betas - best_betas[beta_idx][idx] = betas[beta_idx][idx] for cut_beta_idx in range(len(self.cut_beta_params)): # general cut beta crown general_betas - best_betas[regular_beta_length + cut_beta_idx][:, :, idx, + best_betas['cut'][cut_beta_idx][:, :, idx, :] = betas[regular_beta_length + cut_beta_idx][:, :, idx, :] - else: - for beta_idx in range(len(betas)): - # regular beta crown betas - best_betas[beta_idx][idx] = betas[beta_idx][idx] -def get_optimized_bounds( - self, x=None, aux=None, C=None, IBP=False, forward=False, - method='backward', bound_lower=True, bound_upper=False, +def _get_optimized_bounds( + self: 'BoundedModule', x=None, aux=None, C=None, IBP=False, + forward=False, method='backward', bound_side='lower', reuse_ibp=False, return_A=False, average_A=False, final_node_name=None, - intermediate_layer_bounds=None, reference_bounds=None, + interm_bounds=None, reference_bounds=None, aux_reference_bounds=None, needed_A_dict=None, cutter=None, decision_thresh=None, epsilon_over_decision_thresh=1e-4): """ @@ -417,77 +291,78 @@ def get_optimized_bounds( iteration = opts['iteration'] beta = opts['enable_beta_crown'] alpha = opts['enable_alpha_crown'] - opt_coeffs = opts['opt_coeffs'] - opt_bias = opts['opt_bias'] + apply_output_constraints_to = opts['apply_output_constraints_to'] opt_choice = opts['optimizer'] - single_node_split = opts['single_node_split'] - assert single_node_split is True keep_best = opts['keep_best'] - fix_intermediate_layer_bounds = opts['fix_intermediate_layer_bounds'] - init_alpha = opts['init_alpha'] - lr_alpha = opts['lr_alpha'] - lr_beta = opts['lr_beta'] - lr_cut_beta = opts['lr_cut_beta'] - lr_intermediate_beta = opts['lr_intermediate_beta'] - lr_decay = opts['lr_decay'] - lr_coeffs = opts['lr_coeffs'] + fix_interm_bounds = opts['fix_interm_bounds'] loss_reduction_func = opts['loss_reduction_func'] stop_criterion_func = opts['stop_criterion_func'] use_float64_in_last_iteration = opts['use_float64_in_last_iteration'] early_stop_patience = opts['early_stop_patience'] - intermediate_beta_enabled = opts['intermediate_beta'] start_save_best = opts['start_save_best'] multi_spec_keep_func = opts['multi_spec_keep_func'] + deterministic = opts['deterministic'] enable_opt_interm_bounds = self.bound_opts.get( 'enable_opt_interm_bounds', False) sparse_intermediate_bounds = self.bound_opts.get( 'sparse_intermediate_bounds', False) verbosity = self.bound_opts['verbosity'] - assert bound_lower != bound_upper, ( - 'we can only optimize lower OR upper bound at one time') + if bound_side not in ['lower', 'upper']: + raise ValueError(bound_side) + bound_lower = bound_side == 'lower' + bound_upper = bound_side == 'upper' + assert alpha or beta, ( 'nothing to optimize, use compute bound instead!') if C is not None: self.final_shape = C.size()[:2] self.bound_opts.update({'final_shape': self.final_shape}) - if init_alpha: + if opts['init_alpha']: # TODO: this should set up aux_reference_bounds. - self.init_slope(x, share_slopes=opts['use_shared_alpha'], - method=method, c=C, final_node_name=final_node_name) + self.init_alpha(x, share_alphas=opts['use_shared_alpha'], + method=method, c=C, final_node_name=final_node_name) - # Optimizable activations that are actually used and perturbed - optimizable_activations = [ - n for n in self.optimizable_activations if n.used and n.perturbed] - # Relu node that are actually used - relus = [n for n in self.relus if n.used and n.perturbed] + optimizable_activations = self.get_enabled_opt_act() - alphas, betas, parameters = [], [], [] + alphas, parameters = [], [] dense_coeffs_mask = [] - partial_intermediate_layer_bounds = None - if alpha: best_alphas = _set_alpha( - optimizable_activations, parameters, alphas, lr_alpha) - + optimizable_activations, parameters, alphas, opts['lr_alpha']) if beta: - ret_set_beta = _set_beta( - self, relus, optimizable_activations, single_node_split, - enable_opt_interm_bounds, betas, opt_coeffs, parameters, - lr_coeffs, opt_bias, lr_beta, lr_cut_beta, cutter, - dense_coeffs_mask) - betas, best_betas, coeffs = ret_set_beta[:3] - dense_coeffs_mask, best_coeffs, biases, best_biases = ret_set_beta[3:] + ret_set_beta = self.set_beta( + enable_opt_interm_bounds, parameters, + opts['lr_beta'], opts['lr_cut_beta'], cutter, dense_coeffs_mask) + betas, best_betas, coeffs, dense_coeffs_mask = ret_set_beta[:4] + if apply_output_constraints_to is not None and len(apply_output_constraints_to) > 0: + _set_gammas(self.nodes(), parameters) start = time.time() - if (decision_thresh is not None - and isinstance(decision_thresh, torch.Tensor)): + if isinstance(decision_thresh, torch.Tensor): if decision_thresh.dim() == 1: # add the spec dim to be aligned with compute_bounds return decision_thresh = decision_thresh.unsqueeze(-1) + if opts['pruning_in_iteration']: + if return_A: + raise NotImplementedError( + 'Pruning in iteration optimization does not support ' + 'return A yet. ' + 'Please fix or discard this optimization by setting ' + '--disable_pruning_in_iteration ' + 'or bab: pruning_in_iteration: false') + pruner = OptPruner( + x, threshold=opts['pruning_in_iteration_threshold'], + multi_spec_keep_func=multi_spec_keep_func, + loss_reduction_func=loss_reduction_func, + decision_thresh=decision_thresh, + epsilon_over_decision_thresh=epsilon_over_decision_thresh, + fix_interm_bounds=fix_interm_bounds) + else: + pruner = None if opt_choice == 'adam-autolr': opt = AdamElementLR(parameters) @@ -500,10 +375,7 @@ def get_optimized_bounds( # Create a weight vector to scale learning rate. loss_weight = torch.ones(size=(x[0].size(0),), device=x[0].device) - scheduler = optim.lr_scheduler.ExponentialLR(opt, lr_decay) - - if verbosity > 0 and intermediate_beta_enabled: - self.print_optimized_beta(relus, intermediate_beta_enabled=True) + scheduler = optim.lr_scheduler.ExponentialLR(opt, opts['lr_decay']) # best_intermediate_bounds is linked to aux_reference_bounds! best_intermediate_bounds = {} @@ -516,15 +388,6 @@ def get_optimized_bounds( if aux_reference_bounds is None: aux_reference_bounds = {} - with torch.no_grad(): - pruning_in_iteration = False - # for computing the positive domain ratio - original_size = x[0].shape[0] - preserve_mask = None - - # record the overhead due to extra operations from pruning-in-iteration - pruning_time = 0. - need_grad = True patience = 0 for i in range(iteration): @@ -534,79 +397,66 @@ def get_optimized_bounds( intermediate_constr = None - if not fix_intermediate_layer_bounds: + if not fix_interm_bounds: # If we still optimize all intermediate neurons, we can use - # intermediate_layer_bounds as reference bounds. - reference_bounds = intermediate_layer_bounds + # interm_bounds as reference bounds. + if reference_bounds is None: + reference_bounds = {} + if interm_bounds is not None: + reference_bounds.update(interm_bounds) + interm_bounds = {} if i == iteration - 1: # No grad update needed for the last iteration need_grad = False - if (self.device == 'cuda' and torch.get_default_dtype() == torch.float32 and use_float64_in_last_iteration): - C, x, intermediate_layer_bounds = _to_float64( - self, C, x, aux_reference_bounds, intermediate_layer_bounds) + C, x, interm_bounds = self._to_float64( + C, x, aux_reference_bounds, interm_bounds) + + if pruner: + # we will use last update preserve mask in caller functions to recover + # lA, l, u, etc to full batch size + self.last_update_preserve_mask = pruner.preserve_mask + pruner.cache_full_sized_alpha(optimizable_activations) - # we will use last update preserve mask in caller functions to recover - # lA, l, u, etc to full batch size - self.last_update_preserve_mask = preserve_mask with torch.no_grad() if not need_grad else ExitStack(): # ret is lb, ub or lb, ub, A_dict (if return_A is set to true) - - # argument for intermediate_layer_bounds - # If we set neuron bounds individually, or if we are optimizing - # intermediate layer bounds using beta, we do not set - # intermediate_layer_bounds. - # When intermediate betas are used, we must set - # intermediate_layer_bounds to None because we want to recompute - # all intermediate layer bounds. - if beta and intermediate_beta_enabled: - arg_ilb = partial_intermediate_layer_bounds - elif fix_intermediate_layer_bounds: - arg_ilb = intermediate_layer_bounds - else: - arg_ilb = None - - # argument for aux_reference_bounds - if sparse_intermediate_bounds: - arg_arb = aux_reference_bounds - else: - arg_arb = None - ret = self.compute_bounds( x, aux, C, method=method, IBP=IBP, forward=forward, bound_lower=bound_lower, bound_upper=bound_upper, reuse_ibp=reuse_ibp, return_A=return_A, final_node_name=final_node_name, average_A=average_A, - intermediate_layer_bounds=arg_ilb, + # When intermediate bounds are recomputed, we must set it + # to None + interm_bounds=interm_bounds if fix_interm_bounds else None, # This is the currently tightest interval, which will be used to # pass split constraints when intermediate betas are used. reference_bounds=reference_bounds, # This is the interval used for checking for unstable neurons. - aux_reference_bounds=arg_arb, + aux_reference_bounds=aux_reference_bounds if sparse_intermediate_bounds else None, # These are intermediate layer beta variables and their # corresponding A matrices and biases. intermediate_constr=intermediate_constr, needed_A_dict=needed_A_dict, - update_mask=preserve_mask) - + update_mask=pruner.preserve_mask if pruner else None) ret_l, ret_u = ret[0], ret[1] + if pruner: + pruner.recover_full_sized_alpha(optimizable_activations) + if (self.cut_used and i % cutter.log_interval == 0 and len(self.cut_beta_params) > 0): # betas[-1]: (2(0 lower, 1 upper), spec, batch, num_constrs) if ret_l is not None: - print( - i, 'lb beta sum:', - f'{self.cut_beta_params[-1][0].sum() / ret_l.size(0)},', - f'worst {ret_l.min()}') + print(i, 'lb beta sum:', + f'{self.cut_beta_params[-1][0].sum() / ret_l.size(0)},', + f'worst {ret_l.min()}') if ret_u is not None: - print( - i, 'lb beta sum:', - f'{self.cut_beta_params[-1][1].sum() / ret_u.size(0)},', - f'worst {ret_u.min()}') + print(i, 'lb beta sum:', + f'{self.cut_beta_params[-1][1].sum() / ret_u.size(0)},', + f'worst {ret_u.min()}') if i == 0: # save results at the first iteration @@ -618,9 +468,8 @@ def get_optimized_bounds( ret_0 = ret[0].detach().clone() if bound_lower else ret[1].detach().clone() for node in optimizable_activations: - new_intermediate = [ - node.inputs[0].lower.detach().clone(), - node.inputs[0].upper.detach().clone()] + new_intermediate = [node.inputs[0].lower.detach().clone(), + node.inputs[0].upper.detach().clone()] best_intermediate_bounds[node.name] = new_intermediate if sparse_intermediate_bounds: # Always using the best bounds so far as the reference @@ -640,80 +489,27 @@ def get_optimized_bounds( full_l = l full_ret = ret - # positive domains may already be filtered out, so we use all domains - - # negative domains to compute - if decision_thresh is not None: - if (isinstance(decision_thresh, torch.Tensor) - and decision_thresh.numel() > 1 - and preserve_mask is not None): - if decision_thresh.shape[-1] == 1: - # single spec with pruned domains - negative_domain = ( - ret_l.view(-1) - <= decision_thresh[preserve_mask].view(-1)).sum() - else: - # multiple spec with pruned domains - negative_domain = multi_spec_keep_func( - ret_l <= decision_thresh[preserve_mask]).sum() - else: - if ret_l.shape[-1] == 1: - # single spec - negative_domain = ( - ret_l.view(-1) <= decision_thresh.view(-1)).sum() - else: - # multiple spec - negative_domain = multi_spec_keep_func( - ret_l <= decision_thresh).sum() - positive_domain_num = original_size - negative_domain + if pruner: + (x, C, full_l, full_ret_l, full_ret_u, + full_ret, stop_criterion) = pruner.prune( + x, C, ret_l, ret_u, ret, full_l, full_ret_l, full_ret_u, + full_ret, interm_bounds, aux_reference_bounds, + stop_criterion_func, bound_lower) else: - positive_domain_num = -1 - positive_domain_ratio = float( - positive_domain_num) / float(original_size) - # threshold is 10% by default - next_iter_pruning_in_iteration = ( - opts['pruning_in_iteration'] and decision_thresh is not None - and positive_domain_ratio > opts['pruning_in_iteration_threshold']) - - if pruning_in_iteration: - stime = time.time() - if return_A: - raise Exception( - 'Pruning in iteration optimization does not support ' - 'return A yet. ' - 'Please fix or discard this optimization by setting ' - '--disable_pruning_in_iteration ' - 'or bab: pruning_in_iteration: false') - now_preserve_mask = _get_preserve_mask( - decision_thresh, ret_l, preserve_mask, multi_spec_keep_func) - # prune C - if C is not None and C.shape[0] == x[0].shape[0]: - C = C[now_preserve_mask] # means C is also batch specific - # prune x - x, pre_prune_size = _prune_x(x, now_preserve_mask) - # prune bounds - ret_prune = _prune_bounds_by_mask( - now_preserve_mask, decision_thresh, ret_l, ret_u, ret, - preserve_mask, epsilon_over_decision_thresh, original_size, - loss_reduction_func, beta, intermediate_beta_enabled, - fix_intermediate_layer_bounds, - intermediate_layer_bounds, aux_reference_bounds, - partial_intermediate_layer_bounds, pre_prune_size) - full_l, full_ret_l = ret_prune[:2] - # ret_prune[2] is full_u which is unused - full_ret_u, full_ret, preserve_mask_next = ret_prune[3:] - pruning_time += time.time() - stime + stop_criterion = (stop_criterion_func(full_ret_l) if bound_lower + else stop_criterion_func(-full_ret_u)) loss_ = l if bound_lower else -u - stop_criterion = stop_criterion_func( - full_ret_l) if bound_lower else stop_criterion_func(-full_ret_u) - if (type(stop_criterion) != bool - and stop_criterion.numel() > 1 and pruning_in_iteration): - stop_criterion = stop_criterion[preserve_mask] total_loss = -1 * loss_ + if type(stop_criterion) == bool: loss = total_loss.sum() * (not stop_criterion) else: + assert total_loss.shape == stop_criterion.shape loss = (total_loss * stop_criterion.logical_not()).sum() + # For logging, print the total sum. Otherwise the loss may appear + # to be increasing as more examples are stopped. + loss_sum = total_loss.sum() stop_criterion_final = isinstance( stop_criterion, torch.Tensor) and stop_criterion.all() @@ -728,9 +524,8 @@ def get_optimized_bounds( if (i == iteration - 1 and self.device == 'cuda' and torch.get_default_dtype() == torch.float32 and use_float64_in_last_iteration): - total_loss, x, full_ret = _to_default_dtype( - self, x, total_loss, full_ret, ret, best_intermediate_bounds, - return_A) + total_loss, x, full_ret = self._to_default_dtype( + x, total_loss, full_ret, ret, best_intermediate_bounds, return_A) with torch.no_grad(): # for lb and ub, we update them in every iteration since updating @@ -738,15 +533,13 @@ def get_optimized_bounds( need_update = False if keep_best: if best_ret_u is not None: - ret_upd = _update_best_ret( + best_ret_u, best_ret, need_update, idx_mask, improved_idx = _update_best_ret( full_ret_u, best_ret_u, full_ret, best_ret, need_update, - idx=1) - best_ret_u, best_ret, _, _, need_update = ret_upd + loss_reduction_func, idx=1, deterministic=deterministic) if best_ret_l is not None: - ret_upd = _update_best_ret( + best_ret_l, best_ret, need_update, idx_mask, improved_idx = _update_best_ret( full_ret_l, best_ret_l, full_ret, best_ret, need_update, - idx=0) - best_ret_l, best_ret, _, _, need_update = ret_upd + loss_reduction_func, idx=0, deterministic=deterministic) else: # Not saving the best, just keep the last iteration. if full_ret[0] is not None: @@ -767,48 +560,48 @@ def get_optimized_bounds( # (in case divergence) and second half iterations # or before early stop by either stop_criterion or # early_stop_patience reached - if (i < 1 or i > int(iteration * start_save_best) + if (i < 1 or i > int(iteration * start_save_best) or deterministic or stop_criterion_final or patience == early_stop_patience): # compare with the first iteration results and get improved indexes if bound_lower: - idx_mask, idx = _get_idx_mask(0, full_ret_l, ret_0) + if deterministic: + idx = improved_idx + else: + idx_mask, idx = _get_idx_mask( + 0, full_ret_l, ret_0, loss_reduction_func) ret_0[idx] = full_ret_l[idx] else: - idx_mask, idx = _get_idx_mask(1, full_ret_u, ret_0) + if deterministic: + idx = improved_idx + else: + idx_mask, idx = _get_idx_mask( + 1, full_ret_u, ret_0, loss_reduction_func) ret_0[idx] = full_ret_u[idx] if idx is not None: # for update propose, we condition the idx to update only # on domains preserved - if pruning_in_iteration: - # local sparse index of preserved samples where - # idx == true - local_idx = idx_mask[preserve_mask].nonzero().view(-1) - # idx is global sparse index of preserved samples where - # idx == true - new_idx = torch.zeros_like( - idx_mask, dtype=torch.bool, device=x[0].device) - new_idx[preserve_mask] = idx_mask[preserve_mask] - idx = new_idx.nonzero().view(-1) - reference_idx = local_idx + if pruner: + reference_idx, idx = pruner.prune_idx(idx_mask, idx, x) else: reference_idx = idx _update_optimizable_activations( - optimizable_activations, intermediate_layer_bounds, - fix_intermediate_layer_bounds, best_intermediate_bounds, - reference_idx, idx, alpha, best_alphas) + optimizable_activations, interm_bounds, + fix_interm_bounds, best_intermediate_bounds, + reference_idx, idx, alpha, best_alphas, deterministic) - if beta and single_node_split: - _update_best_beta( - self, enable_opt_interm_bounds, betas, - optimizable_activations, best_betas, idx) + if beta: + self.update_best_beta(enable_opt_interm_bounds, betas, + best_betas, idx) if os.environ.get('AUTOLIRPA_DEBUG_OPT', False): print(f'****** iter [{i}]', - f'loss: {loss.item()}, lr: {opt.param_groups[0]["lr"]}') + f'loss: {loss_sum.item()}, lr: {opt.param_groups[0]["lr"]}', + (' pruning_in_iteration open status: ' + f'{pruner.pruning_in_iteration}') if pruner else '') if stop_criterion_final: print(f'\nall verified at {i}th iter') @@ -820,20 +613,22 @@ def get_optimized_bounds( ' iterations no improvement!') break - current_lr = [param_group['lr'] for param_group in opt.param_groups] + if i != iteration - 1 and not loss.requires_grad: + assert i == 0, (i, iteration) + print('[WARNING] No optimizable parameters found. Will skip optimiziation. ' + 'This happens e.g. if all optimizable layers are freezed or the ' + 'network has no optimizable layers.') + break opt.zero_grad(set_to_none=True) if verbosity > 2: - print( - f'*** iter [{i}]\n', f'loss: {loss.item()}', - total_loss.squeeze().detach().cpu().numpy(), 'lr: ', - current_lr) + current_lr = [param_group['lr'] for param_group in opt.param_groups] + print(f'*** iter [{i}]\n', f'loss: {loss.item()}', + total_loss.squeeze().detach().cpu().numpy(), 'lr: ', + current_lr) if beta: - self.print_optimized_beta(relus, intermediate_beta_enabled) - if opt_coeffs: - for co in coeffs: - print(f'coeff sum: {co.abs().sum():.5g}') + print_optimized_beta(optimizable_activations) if beta and i == 0 and verbosity > 2: breakpoint() @@ -841,6 +636,7 @@ def get_optimized_bounds( # we do not need to update parameters in the last step since the # best result already obtained loss.backward() + # All intermediate variables are not needed at this point. self._clear_and_set_new(None) if opt_choice == 'adam-autolr': @@ -849,7 +645,6 @@ def get_optimized_bounds( opt.step() if beta: - # Clipping to >=0. for b in betas: b.data = (b >= 0) * b.data for dmi in range(len(dense_coeffs_mask)): @@ -860,50 +655,23 @@ def get_optimized_bounds( if alpha: for m in optimizable_activations: - m.clip_alpha_() + m.clip_alpha() + if apply_output_constraints_to is not None and len(apply_output_constraints_to) > 0: + for m in self.nodes(): + m.clip_gammas() scheduler.step() - if pruning_in_iteration: - preserve_mask = preserve_mask_next - if not pruning_in_iteration and next_iter_pruning_in_iteration: - # init preserve_mask etc - preserve_mask = torch.arange( - 0, x[0].shape[0], device=x[0].device, dtype=torch.long) - pruning_in_iteration = True + if pruner: + pruner.next_iter() - if pruning_in_iteration: - # overwrite pruned cells in best_ret by threshold + eps - if return_A: - fin_l, fin_u, fin_A = best_ret - else: - fin_l, fin_u = best_ret - fin_A = None - if fin_l is not None: - new_fin_l = full_ret_l - new_fin_l[preserve_mask] = fin_l[preserve_mask] - fin_l = new_fin_l - if fin_u is not None: - new_fin_u = full_ret_u - new_fin_u[preserve_mask] = fin_u[preserve_mask] - fin_u = new_fin_u - if return_A: - best_ret = (fin_l, fin_u, fin_A) - else: - best_ret = (fin_l, fin_u) + if pruner: + best_ret = pruner.update_best(full_ret_l, full_ret_u, best_ret) if verbosity > 3: breakpoint() if keep_best: - def update_best(dest, src): - for item_dest, item_src in zip(dest, src): - if enable_opt_interm_bounds: - for key in item_dest.keys(): - item_dest[key].data = item_src[key].data - else: - item_dest.data = item_src.data - # Set all variables to their saved best values. with torch.no_grad(): for idx, node in enumerate(optimizable_activations): @@ -912,149 +680,152 @@ def update_best(dest, src): node.alpha = best_alphas[node.name] # Update best intermediate layer bounds only when they are # optimized. If they are already fixed in - # intermediate_layer_bounds, then do nothing. + # interm_bounds, then do nothing. best_intermediate = best_intermediate_bounds[node.name] node.inputs[0].lower.data = best_intermediate[0].data node.inputs[0].upper.data = best_intermediate[1].data - if beta: - if (single_node_split and hasattr(node, 'sparse_beta') - and node.sparse_beta is not None): - if enable_opt_interm_bounds: - for key in node.sparse_beta.keys(): - node.sparse_beta[key].copy_( - best_betas[node.name][key]) - else: - node.sparse_beta.copy_(best_betas[idx]) + if beta: + for node in self.nodes_with_beta: + assert getattr(node, 'sparse_betas', None) is not None + if enable_opt_interm_bounds: + for key in node.sparse_betas.keys(): + node.sparse_betas[key].val.copy_( + best_betas[node.name][key]) else: - update_best(betas, best_betas) - if opt_coeffs: - update_best(coeffs, best_coeffs) - if opt_bias: - update_best(biases, best_biases) + node.sparse_betas[0].val.copy_(best_betas[node.name]) if self.cut_used: - regular_beta_length = len(betas) - len(self.cut_beta_params) for ii in range(len(self.cut_beta_params)): - self.cut_beta_params[ii].data = best_betas[ - regular_beta_length + ii].data + self.cut_beta_params[ii].data = best_betas['cut'][ii].data - if (intermediate_layer_bounds is not None - and not fix_intermediate_layer_bounds): + if interm_bounds is not None and not fix_interm_bounds: for l in self._modules.values(): - if (l.name in intermediate_layer_bounds.keys() + if (l.name in interm_bounds.keys() and hasattr(l, 'lower')): - l.lower = torch.max( - l.lower, intermediate_layer_bounds[l.name][0]) - l.upper = torch.min( - l.upper, intermediate_layer_bounds[l.name][1]) + l.lower = torch.max(l.lower, interm_bounds[l.name][0]) + l.upper = torch.min(l.upper, interm_bounds[l.name][1]) infeasible_neurons = l.lower > l.upper if infeasible_neurons.any(): - print( - f'Infeasibility detected in layer {l.name}.', - infeasible_neurons.sum().item(), - infeasible_neurons.nonzero()[:, 0]) + print(f'Infeasibility detected in layer {l.name}.', + infeasible_neurons.sum().item(), + infeasible_neurons.nonzero()[:, 0]) if verbosity > 0: - if self.cut_used and beta: - print( - 'first 10 best general betas:', - best_betas[-1].view(2, -1)[0][:10], 'sum:', - best_betas[-1][0].sum().item()) if best_ret_l is not None: # FIXME: unify the handling of l and u. - print( - 'best_l after optimization:', - best_ret_l.sum().item(), 'with beta sum per layer:', - [p.sum().item() for p in betas]) + print('best_l after optimization:', best_ret_l.sum().item()) + if beta: + print('beta sum per layer:', [p.sum().item() for p in betas]) print('alpha/beta optimization time:', time.time() - start) for node in optimizable_activations: node.opt_end() - # update pruning ratio - if (opts['pruning_in_iteration'] and decision_thresh is not None - and full_l.numel() > 0): - stime = time.time() - with torch.no_grad(): - if isinstance(decision_thresh, torch.Tensor): - if decision_thresh.shape[-1] == 1: - neg_domain_num = torch.sum( - full_ret_l.view(-1) <= decision_thresh.view(-1)).item() - else: - neg_domain_num = torch.sum(multi_spec_keep_func( - full_ret_l <= decision_thresh)).item() - else: - if full_l.shape[-1] == 1: - neg_domain_num = torch.sum( - full_ret_l.view(-1) <= decision_thresh).item() - else: - neg_domain_num = torch.sum(multi_spec_keep_func( - full_ret_l <= decision_thresh)).item() - now_pruning_ratio = (1.0 - - float(neg_domain_num) / float(full_l.shape[0])) - print('pruning_in_iteration open status:', pruning_in_iteration) - print( - 'ratio of positive domain =', full_l.shape[0] - neg_domain_num, - '/', full_l.numel(), '=', now_pruning_ratio) - pruning_time += time.time() - stime - print('pruning-in-iteration extra time:', pruning_time) + if pruner: + pruner.update_ratio(full_l, full_ret_l) + pruner.clean_full_sized_alpha_cache() + + if os.environ.get('AUTOLIRPA_DEBUG_OPT', False): + print() return best_ret -def init_slope( - self, x, share_slopes=False, method='backward', - c=None, bound_lower=True, bound_upper=True, final_node_name=None, - intermediate_layer_bounds=None, activation_opt_params=None, - skip_bound_compute=False): - for node in self.optimizable_activations: +def init_alpha(self: 'BoundedModule', x, share_alphas=False, method='backward', + c=None, bound_lower=True, bound_upper=True, final_node_name=None, + interm_bounds=None, activation_opt_params=None, + skip_bound_compute=False): + self(*x) # Do a forward pass to set perturbed nodes + final = (self.final_node() if final_node_name is None + else self[final_node_name]) + self._set_used_nodes(final) + + optimizable_activations = self.get_enabled_opt_act() + for node in optimizable_activations: + # TODO(7/6/2023) In the future, we may need to enable alpha sharing + # automatically by consider the size of all the optimizable nodes in the + # graph. For now, only an adhoc check in MatMul is added. + node._all_optimizable_activations = optimizable_activations + # initialize the parameters node.opt_init() - if (not skip_bound_compute or intermediate_layer_bounds is None or + apply_output_constraints_to = ( + self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] + ) + if (not skip_bound_compute or interm_bounds is None or activation_opt_params is None or not all( - [relu.name in activation_opt_params for relu in self.relus])): + [act.name in activation_opt_params + for act in self.optimizable_activations])): skipped = False # if new interval is None, then CROWN interval is not present # in this case, we still need to redo a CROWN pass to initialize # lower/upper with torch.no_grad(): + # We temporarilly deactivate output constraints + self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] = [] l, u = self.compute_bounds( x=x, C=c, method=method, bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name, - intermediate_layer_bounds=intermediate_layer_bounds) + interm_bounds=interm_bounds) + self.bound_opts['optimize_bound_args']['apply_output_constraints_to'] = ( + apply_output_constraints_to + ) else: # we skip, but we still would like to figure out the "used", # "perturbed", "backward_from" of each note in the graph skipped = True # this set the "perturbed" property - self._set_input( - *x, intermediate_layer_bounds=intermediate_layer_bounds) - - final = self.final_node( - ) if final_node_name is None else self[final_node_name] - self._set_used_nodes(final) - + self.set_input(*x, interm_bounds=interm_bounds) self.backward_from = {node: [final] for node in self._modules} final_node_name = final_node_name or self.final_name init_intermediate_bounds = {} - for node in self.optimizable_activations: - if not node.used or not node.perturbed: - continue + for node in optimizable_activations: start_nodes = [] if method in ['forward', 'forward+backward']: - start_nodes.append(('_forward', 1, None)) + start_nodes.append(('_forward', 1, None, False)) if method in ['backward', 'forward+backward']: + if ( + apply_output_constraints_to is not None + and len(apply_output_constraints_to) > 0 + ): + input_node = None + for potential_input_node_name in self.input_name: + if type(self[potential_input_node_name]) is BoundInput: + assert input_node is None, 'Only a single input node is supported' + input_node = self[potential_input_node_name] + assert input_node is not None + + backward_from_node = input_node + else: + backward_from_node = node start_nodes += self.get_alpha_crown_start_nodes( - node, c=c, share_slopes=share_slopes, - final_node_name=final_node_name) + node, + c=c, + share_alphas=share_alphas, + final_node_name=final_node_name, + backward_from_node=backward_from_node + ) if skipped: node.restore_optimized_params(activation_opt_params[node.name]) else: node.init_opt_parameters(start_nodes) - init_intermediate_bounds[node.inputs[0].name] = ( - [node.inputs[0].lower.detach(), node.inputs[0].upper.detach()]) + if node in self.splittable_activations: + for i in node.requires_input_bounds: + input_node = node.inputs[i] + if not input_node.perturbed: + continue + init_intermediate_bounds[node.inputs[i].name] = ( + [node.inputs[i].lower.detach(), + node.inputs[i].upper.detach()]) + if ( + apply_output_constraints_to is not None + and len(apply_output_constraints_to) > 0 + and hasattr(self, 'constraints') + ): + for node in self.nodes(): + node.init_gammas(self.constraints.size(0)) if self.bound_opts['verbosity'] >= 1: print('Optimizable variables initialized.') diff --git a/auto_LiRPA/output_constraints.py b/auto_LiRPA/output_constraints.py new file mode 100644 index 0000000..462eb80 --- /dev/null +++ b/auto_LiRPA/output_constraints.py @@ -0,0 +1,282 @@ + +from .utils import * +from .bound_ops import * +from .operators import Bound + +from typing import TYPE_CHECKING, Optional, List +if TYPE_CHECKING: + from .bound_general import BoundedModule + + +def backward_general_with_output_constraint( + self: 'BoundedModule', + bound_node, + C, + start_backporpagation_at_node = None, + bound_lower=True, + bound_upper=True, + average_A=False, + need_A_only=False, + unstable_idx=None, + update_mask=None, + verbose=True, +): + assert start_backporpagation_at_node is None + assert not isinstance(C, str) + + neurons_in_layer = 1 + for d in bound_node.output_shape[1:]: + neurons_in_layer *= d + + # backward_general uses C to compute batch_size, output_dim and output_shape, just like below. + # When output constraints are applied, it will perform a different backpropagation, + # but those variables need to be computed regardless. So we need to retain the original C + # and pass it on to backward_general. If initial_As is set (which it is, if this code here + # is executed), it will not use C for anything else. + orig_C = C + + C, batch_size, output_dim, output_shape = self._preprocess_C(C, bound_node) + device = bound_node.device + + num_constraints = self.constraints.size(0) + + # 1) Linear: Hx + d + # Result is a tensor, <= 0 for all entries if output constraint is satisfied + H = self.constraints.T # (output_neurons, constraints) + d = self.thresholds.squeeze(0) # (constraints) + assert H.ndim == 2 + assert H.size(1) == num_constraints + assert d.ndim == 1 + assert d.size(0) == num_constraints + + linear_Hxd_layer_weight_value = nn.Parameter(H.to(C)) + linear_Hxd_layer_weight = BoundParams( + ori_name="/linear_Hxd_layer_weight", + value=None, + perturbation=None, + ) + linear_Hxd_layer_weight.name = "linear_Hxd_layer_weight" + linear_Hxd_layer_weight.lower = linear_Hxd_layer_weight_value + linear_Hxd_layer_weight.upper = linear_Hxd_layer_weight_value + + linear_Hxd_layer_bias_value = nn.Parameter(d.float().to(device)) + linear_Hxd_layer_bias = BoundParams( + ori_name="/linear_Hxd_layer_bias", + value=None, + perturbation=None, + ) + linear_Hxd_layer_bias.name = "linear_Hxd_layer_bias" + linear_Hxd_layer_bias.lower = linear_Hxd_layer_bias_value + linear_Hxd_layer_bias.upper = linear_Hxd_layer_bias_value + + linear_Hxd_layer = BoundLinear( + attr=None, + inputs=[ + self.final_node(), + linear_Hxd_layer_weight, + linear_Hxd_layer_bias, + ], + output_index=0, + options=self.bound_opts, + ) + linear_Hxd_layer.name = "/linear_Hxd_layer" + linear_Hxd_layer.device = device + linear_Hxd_layer.perturbed = True + linear_Hxd_layer.output_shape = torch.Size([1, num_constraints]) + linear_Hxd_layer.batch_dim = bound_node.batch_dim + + # 2) Gamma + # A seperate gamma per output constraint. All gammas are always positive. + # Note that we're not using a different gamma per neuron in the optimized layer. + # That would be even more precise, but much slower and would require more memory. + gamma_layer_weight = BoundParams( + ori_name="/gamma_layer_weight", + value=None, + perturbation=None, + ) + gamma_layer_weight.name = "gamma_layer_weight" + assert bound_node.gammas.ndim == 2 + assert bound_node.gammas.size(0) == 2 + assert bound_node.gammas.size(1) == num_constraints + gamma_layer_weight.lower = torch.diag(bound_node.gammas[0]) # (5, 5) + gamma_layer_weight.upper = torch.diag(-bound_node.gammas[1]) # (5, 5) + gamma_layer = BoundLinear( + attr=None, + inputs=[linear_Hxd_layer, gamma_layer_weight], + output_index=0, + options=self.bound_opts, + ) + gamma_layer.name = "/gamma_layer" + gamma_layer.device = device + gamma_layer.perturbed = True + gamma_layer.input_shape = linear_Hxd_layer.output_shape + gamma_layer.output_shape = gamma_layer.input_shape + gamma_layer.batch_dim = bound_node.batch_dim + gamma_layer.use_seperate_weights_for_lower_and_upper_bounds = True + + # 3) Sum + # Sum over all constraints. + # In the dualization, if there are multiple output constraints, we have + # min g(x) + gamma_1... + gamma_2... or + # max g(x) - gamma_1... - gamma_2... + # Here, we only compute the sum over all gammas, the addition of g(x) is handled + # further down. + sum_weight_value = nn.Parameter(torch.ones((5,1), device=device)) + sum_weight = BoundParams( + ori_name="/sum_weight", + value=None, + perturbation=None, + ) + sum_weight.name = "sum_weight" + sum_weight.lower = sum_weight_value + sum_weight.upper = sum_weight_value + sum_layer = BoundLinear( + attr=None, + inputs=[gamma_layer, sum_weight], + output_index=0, + options=self.bound_opts, + ) + sum_layer.name = "/sum_layer" + sum_layer.device = device + sum_layer.perturbed = True + sum_layer.input_shape = gamma_layer.output_shape + sum_layer.output_shape = torch.Size([1, 1]) + sum_layer.batch_dim = bound_node.batch_dim + + # 4) Repeat + # One copy per neuron in the layer that should be optimized. + repeat_layer_weight_value = nn.Parameter(torch.ones((1, neurons_in_layer), device=device)) + repeat_layer_weight = BoundParams( + ori_name="/repeat_layer_weight", + value=repeat_layer_weight_value, + perturbation=None, + ) + repeat_layer_weight.name = "repeat_layer_weight" + repeat_layer_weight.lower = repeat_layer_weight_value + repeat_layer_weight.upper = repeat_layer_weight_value + repeat_layer = BoundLinear( + attr=None, + inputs=[sum_layer, repeat_layer_weight], + output_index=0, + options=self.bound_opts, + ) + repeat_layer.name = "/repeat_layer" + repeat_layer.device = device + repeat_layer.perturbed = True + repeat_layer.input_shape = sum_layer.output_shape + repeat_layer.output_shape = torch.Size([1, neurons_in_layer]) + repeat_layer.batch_dim = bound_node.batch_dim + + # 5) Reshape + # To the same shape as the layer that's optimized. + reshape_layer_output_shape = BoundBuffers( + ori_name="/reshape_layer_output_shape", + value = torch.tensor(bound_node.output_shape[1:]), + perturbation=None, + options=self.bound_opts, + ) + reshape_layer_output_shape.name = "reshape_layer_output_shape" + reshape_layer = BoundReshape( + attr=None, + inputs = [repeat_layer, reshape_layer_output_shape], + output_index=0, + options=self.bound_opts, + ) + reshape_layer.name = "/reshape_layer" + reshape_layer.device = device + reshape_layer.perturbed = True + reshape_layer.input_shape = repeat_layer.output_shape + reshape_layer.output_shape = bound_node.output_shape + reshape_layer.batch_dim = bound_node.batch_dim + + # The residual connection that connects the optimized layer and the reshape + # layer from above is not explicitly coded, it's handled implicitly: + # Here, we propagate backwards through 5->4->3->2->1->regular output layer and let + # CROWN handle the propagation from there on backwards to the input layer. + # The other half of the residual connection is implemented by explicitly setting + # the .lA and .uA values of the optimized layer to C. + # This is done via initial_As, initial_lb, initial_ub. + + if True or not isinstance(bound_node, BoundLinear): + if isinstance(C, OneHotC): + batch_size = C.shape[1] + assert C.shape[0] <= C.shape[2] + assert len(C.shape) == 3 + # This is expensive, but Reshape doesn't support OneHotC objects + C = torch.eye(C.shape[2], device=C.device)[C.index].unsqueeze(1).repeat(1, batch_size, 1) + + start_shape = None + lA = C if bound_lower else None + uA = C if bound_upper else None + + # 5) Reshape + A, lower_b, upper_b = reshape_layer.bound_backward( + lA, uA, *reshape_layer.inputs, + start_node=bound_node, unstable_idx=unstable_idx, + start_shape=start_shape) + assert lower_b == 0 + assert upper_b == 0 + lA = A[0][0] + uA = A[0][1] + + # 4) Repeat + A, lower_b, upper_b = repeat_layer.bound_backward( + lA, uA, *repeat_layer.inputs, + start_node=bound_node, unstable_idx=unstable_idx, + start_shape=start_shape) + assert lower_b == 0 + assert upper_b == 0 + lA = A[0][0] + uA = A[0][1] + + # 3) Sum + A, lower_b, upper_b = sum_layer.bound_backward( + lA, uA, *sum_layer.inputs, + start_node=bound_node, unstable_idx=unstable_idx, + start_shape=start_shape) + assert lower_b == 0 + assert upper_b == 0 + lA = A[0][0] + uA = A[0][1] + + # 2) Gamma + A, lower_b, upper_b = gamma_layer.bound_backward( + lA, uA, *gamma_layer.inputs, + start_node=bound_node, unstable_idx=unstable_idx, + start_shape=start_shape) + assert lower_b == 0 + assert upper_b == 0 + lA = A[0][0] + uA = A[0][1] + + # 1) Hx + d + A, lower_b, upper_b = linear_Hxd_layer.bound_backward( + lA, uA, *linear_Hxd_layer.inputs, + start_node=bound_node, unstable_idx=unstable_idx, + start_shape=start_shape) + # lower_b and upper_b are no longer 0, because d wasn't 0. + lA = A[0][0] + uA = A[0][1] + + # This encodes the residual connection. + initial_As = { + self.final_node().name: (lA, uA), + bound_node.name: (C, C), + } + + return self.backward_general( + bound_node = bound_node, + start_backpropagation_at_node = self.final_node(), + C = orig_C, # only used for batch_size, output_dim, output_shape computation + bound_lower = bound_lower, + bound_upper = bound_upper, + average_A = average_A, + need_A_only = need_A_only, + unstable_idx = unstable_idx, + update_mask = update_mask, + verbose = verbose, + apply_output_constraints_to = [], # no nested application + initial_As = initial_As, + initial_lb = lower_b, + initial_ub = upper_b, + ) diff --git a/auto_LiRPA/parse_graph.py b/auto_LiRPA/parse_graph.py index 3137b7c..6325ca6 100644 --- a/auto_LiRPA/parse_graph.py +++ b/auto_LiRPA/parse_graph.py @@ -15,6 +15,14 @@ def get_node_name(node): return node.debugName() +def get_node_attribute(node, attribute_name): + if hasattr(torch.onnx.symbolic_helper, '_node_get'): + # Pytorch >= 1.13. + return torch.onnx.symbolic_helper._node_get(node, attribute_name) + else: + # Pytorch <= 1.12. This will call _node_getitem in torch.onnx.utils. + return node[attribute_name] + def parse_graph(graph, inputs, params): input_all = [] input_used = [] @@ -42,10 +50,9 @@ def name_with_scope(node): nodesOP = [] for n in graph.nodes(): - attrs = {k: n[k] for k in n.attributeNames()} + attrs = {k: get_node_attribute(n, k) for k in n.attributeNames()} n_inputs = [name_with_scope(i) for i in n.inputs()] for i, out in enumerate(list(n.outputs())): - nodesOP.append(Node(**{'name': name_with_scope(out), 'op': n.kind(), 'inputs': n_inputs, @@ -87,7 +94,7 @@ def name_with_scope(node): perturbation = inputs_and_params[i][1].ptb else: perturbation = None - if n.type().sizes() != list(inputs_and_params[i][1].size()): + if i > 0 and n.type().sizes() != list(inputs_and_params[i][1].size()): raise RuntimeError("Input tensor shapes do not much: {} != {}".format( n.type().sizes(), list(inputs_and_params[i][1].size()))) nodesIn[i] = Node(**{'name': name_with_scope(n), @@ -114,7 +121,7 @@ def _get_jit_params(module, param_exclude, param_include): new_state_dict = OrderedDict() for k, v in state_dict.items(): if param_exclude is not None and param_exclude.match(k) is not None: - print('\nremove input element {} from nodesIn\n'.format(k)) + print(f'\nremove input element {k} from nodesIn\n') continue if param_include is not None and param_include.match(k) is None: continue @@ -124,9 +131,9 @@ def _get_jit_params(module, param_exclude, param_include): return params -"""Construct a template for the module output with `None` representing places -to be filled with tensor results""" def get_output_template(out): + """Construct a template for the module output with `None` representing places + to be filled with tensor results""" if isinstance(out, torch.Tensor): return None elif isinstance(out, list): @@ -145,9 +152,16 @@ def parse_module(module, inputs, param_exclude=".*AuxLogits.*", param_include=No params = _get_jit_params(module, param_exclude=param_exclude, param_include=param_include) trace, out = torch.jit._get_trace_graph(module, inputs) _set_opset_version(12) + + # Assuming that the first node in the graph is the primary input node. + # It must have a batch dimension. + primary_input = get_node_name(next(iter(trace.inputs()))) trace_graph = _optimize_graph( - trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, params_dict={}) - logger.debug('trace_graph: {}'.format(trace_graph)) + trace, torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, + params_dict={}, + input_names=[primary_input], + dynamic_axes={primary_input: {0: 'batch'}}) + logger.debug('trace_graph: %s', trace_graph) if int(os.environ.get('AUTOLIRPA_DEBUG_GRAPH', 0)) > 0: print("Graph before ONNX convertion:") diff --git a/auto_LiRPA/patches.py b/auto_LiRPA/patches.py index f79b360..8f6d7e3 100644 --- a/auto_LiRPA/patches.py +++ b/auto_LiRPA/patches.py @@ -39,31 +39,56 @@ def insert_zeros(image, s): return matrix -def remove_zeros(image, s): +def remove_zeros(image, s, remove_zero_start_idx=(0,0)): if s <= 0: return image matrix_stride = image.stride() + storage_offset = image.storage_offset() return torch.as_strided(image, [ # Shape of the output matrix. *image.shape[:-2], - (image.size(-2) + 1) // 2, # H (without zeros) - (image.size(-1) + 1) // 2, # W (without zeros) + (image.size(-2) - remove_zero_start_idx[-2] + (s + 1) - 1) // (s + 1), # H (without zeros) + (image.size(-1) - remove_zero_start_idx[-1] + (s + 1) - 1) // (s + 1), # W (without zeros) ], [ # Stride of the output matrix. *matrix_stride[:-2], matrix_stride[-2] * (s + 1), # Move s+1 rows. matrix_stride[-1] * (s + 1), # Move s+1 pixels. - ]) + ], + storage_offset + matrix_stride[-2] * remove_zero_start_idx[-2] + matrix_stride[-1] * remove_zero_start_idx[-1] + ) def unify_shape(shape): - """Convert shapes to 4-tuple.""" + """ + Convert shapes to 4-tuple: (left, right, top, bottom). + """ if shape is not None: if isinstance(shape, int): + # Same on all four directions. shape = (shape, shape, shape, shape) if len(shape) == 2: + # (height direction, width direction). shape = (shape[1], shape[1], shape[0], shape[0]) assert len(shape) == 4 + # Returned: (left, right, top, bottom). + return shape + + +def simplify_shape(shape): + """ + Convert shapes to 2-tuple or a single number. + Used to avoid extra padding operation because the padding + operation in F.conv2d is not general enough. + """ + if len(shape) == 4: + # 4-tuple: (left, right, top, bottom). + if shape[0] == shape[1] and shape[2] == shape[3]: + shape = (shape[2], shape[0]) + if len(shape) == 2: + # 2-tuple: (height direction, width direction). + if shape[0] == shape[1]: + shape = shape[0] return shape @@ -146,6 +171,14 @@ def __add__(self, other): output_shape=self.output_shape, unstable_idx=self.unstable_idx) return A1_matrix.transpose(0, 1) + matrix + def __str__(self): + return ( + f"Patches(stride={self.stride}, padding={self.padding}, " + f"output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}, " + f"kernel_shape={list(self.patches.shape)}, input_shape={self.input_shape}, " + f"output_shape={self.output_shape}, unstable_idx={type(self.unstable_idx)})" + ) + @property def device(self): if self.patches is not None: @@ -164,12 +197,15 @@ def create_similar(self, patches=None, stride=None, padding=None, identity=None, Create a new Patches object with new patches weights, and keep other properties the same. """ new_patches = self.patches if patches is None else patches + new_identity = self.identity if identity is None else identity + if new_identity and (new_patches is not None): + raise ValueError("Identity Patches should have .patches property set to 0.") return Patches( new_patches, stride=self.stride if stride is None else stride, padding=self.padding if padding is None else padding, shape=new_patches.shape, - identity=self.identity if identity is None else identity, + identity=new_identity, unstable_idx=self.unstable_idx if unstable_idx is None else unstable_idx, output_shape=self.output_shape if output_shape is None else output_shape, inserted_zeros=self.inserted_zeros if inserted_zeros is None else inserted_zeros, @@ -178,23 +214,35 @@ def create_similar(self, patches=None, stride=None, padding=None, identity=None, ) def to_matrix(self, input_shape): - assert self.inserted_zeros == 0 assert not is_shape_used(self.output_padding) - return patches_to_matrix(self.patches, input_shape, self.stride, self.padding, self.output_shape, self.unstable_idx) + return patches_to_matrix( + self.patches, input_shape, self.stride, self.padding, + self.output_shape, self.unstable_idx, self.inserted_zeros + ) def simplify(self): """Merge stride and inserted_zeros; if they are the same they can cancel out.""" stride = [self.stride, self.stride] if isinstance(self.stride, int) else self.stride - if self.inserted_zeros > 0 and self.inserted_zeros + 1 == stride[0] and stride[0] == stride[1]: + if (self.inserted_zeros > 0 and self.inserted_zeros + 1 == stride[0] and + stride[0] == stride[1] and (self.patches.size(-1) % stride[1]) == 0 and (self.patches.size(-2) % stride[0]) == 0): # print(f'before simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}') full_stride = [stride[1], stride[1], stride[0], stride[0]] # output_padding = tuple(p // s for p, s in zip(output_padding, full_stride)) - self.padding = tuple(p // s - o for p, s, o in zip(self.padding, full_stride, unify_shape(self.output_padding))) - self.patches = remove_zeros(self.patches, self.inserted_zeros) - self.stride = 1 - self.inserted_zeros = 0 - self.output_padding = 0 - # print(f'after simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}') + padding = unify_shape(self.padding) + # since inserted_zero will not put zeros to both end, like [x 0 0 x 0 0 x] instead of [x 0 0 x 0 0 x 0 0] + # when computing the simplified padding, we should view (inserted_zeros-1) padding entries from one end side + # as part of the inserted_zero matrices (i.e., "consumed") + consumed_padding = (padding[0], padding[1] - (stride[1] - 1), padding[2], padding[3] - (stride[0] - 1)) + tentative_padding = tuple(p // s - o for p, s, o in zip(consumed_padding, full_stride, unify_shape(self.output_padding))) + # negative padding is inconvenient + if all([p >= 0 for p in tentative_padding]): + remove_zero_start_idx = (padding[2] % stride[0], padding[0] % stride[1]) + self.padding = tentative_padding + self.patches = remove_zeros(self.patches, self.inserted_zeros, remove_zero_start_idx=remove_zero_start_idx) + self.stride = 1 + self.inserted_zeros = 0 + self.output_padding = 0 + # print(f'after simplify: patches={self.patches.size()} padding={self.padding}, stride={self.stride}, output_padding={self.output_padding}, inserted_zeros={self.inserted_zeros}') def matmul(self, input, patch_abs=False, input_shape=None): """ @@ -269,7 +317,8 @@ def compute_patches_stride_padding(input_shape, patches_padding, patches_stride, return new_padding, new_stride, new_output_padding -def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, unstable_idx=None): +def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, + unstable_idx=None, inserted_zeros=0): """Converting a Patches piece into a full dense matrix.""" if type(padding) == int: padding = (padding, padding, padding, padding) @@ -292,6 +341,9 @@ def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, u input_channel, kernel_x, kernel_y = pieces.shape[-3:] input_x, input_y = input_shape[-2:] + if inserted_zeros > 0: + input_x, input_y = (input_x - 1) * (inserted_zeros + 1) + 1, (input_y - 1) * (inserted_zeros + 1) + 1 + if unstable_idx is None: # Fix all patches in a full A matrix. A_matrix = torch.zeros(batch_size, output_channel, output_x, output_y, input_channel, (input_x + padding[2] + padding[3]) * (input_y + padding[0] + padding[1]), device=pieces.device, dtype=pieces.dtype) @@ -325,6 +377,9 @@ def patches_to_matrix(pieces, input_shape, stride, padding, output_shape=None, u A_matrix = A_matrix[:,:,:,padding[2]:input_x + padding[2],padding[0]:input_y + padding[0]] + if inserted_zeros > 0: + A_matrix = A_matrix[:,:,:, ::(inserted_zeros+1), ::(inserted_zeros+1)] + return A_matrix @@ -394,7 +449,6 @@ def inplace_unfold(image, kernel_size, stride=1, padding=0, inserted_zeros=0, ou # Output shape is (batch_size, patches_h, patches_w, channel, kernel_height, kernel_width) if sum(output_padding) > 0: output_padding = tuple(p if p > 0 else None for p in output_padding) - output_padding = (output_padding) matrix_strided = matrix_strided[:, output_padding[2]:-output_padding[3] if output_padding[3] is not None else None, output_padding[0]:-output_padding[1] if output_padding[1] is not None else None, :, :, :] return matrix_strided @@ -431,9 +485,11 @@ def maybe_unfold_patches(d_tensor, last_A, alpha_lookup_idx=None): d_unfolded_r = d_unfolded.view(*d_shape[:-3], *d_unfolded.shape[1:]) if last_A.unstable_idx is not None: # Here we have d for all output neurons, but we only need to select unstable ones. - if d_unfolded_r.size(0) == 1: + if d_unfolded_r.size(0) == 1 and alpha_lookup_idx is None: # Shared alpha, spasre alpha should not be used. - assert alpha_lookup_idx is None + # Note: only d_unfolded_r.size(0) == 1 cannot judge that it is a shared alpha, + # since the activation may have no unstable neuron at all so + # the first dim = 1 + # unstable neuron still equals to 1 if len(last_A.unstable_idx) == 3: # Broadcast the spec shape, so only need to select the rest dimensions. # Change shape to (out_h, out_w, batch, in_c, H, W) or (out_h, out_w, in_c, H, W). @@ -512,3 +568,34 @@ def maybe_unfold_patches(d_tensor, last_A, alpha_lookup_idx=None): # the out_h, out_w dimension and out_c = 1 (sepc). We added 1s for the out_h, out_w dimensions. d_unfolded_r = d_unfolded_r.unsqueeze(2).unsqueeze(-4) return d_unfolded_r + +def create_valid_mask(output_shape, device, dtype, kernel_size, stride, inserted_zeros, padding, output_padding, + unstable_idx=None): + """ + Create a 0-1 mask of patch pieces shape (except batch dim), + where 1 indicates the cells corresponding to valid image pixels + Can be used to mask out unused A cells + :return: tensor of batch pieces shape, containing the binary mask + """ + one_d = torch.ones( + tuple(1 for i in output_shape[1:]), + device=device, dtype=dtype + ).expand(output_shape[1:]) + # Add batch dimension. + one_d = one_d.unsqueeze(0) + # After unfolding, the shape is (1, out_h, out_w, in_c, h, w) + one_d_unfolded = inplace_unfold( + one_d, kernel_size=kernel_size, + stride=stride, padding=padding, + inserted_zeros=inserted_zeros, + output_padding=output_padding) + if unstable_idx is not None: + # Move out_h, out_w dimension to the front for easier selection. + ans = one_d_unfolded.permute(1, 2, 0, 3, 4, 5) + # for sparse patches the shape is (unstable_size, batch, in_c, h, w). + # Batch size is 1 so no need to select here. + ans = ans[unstable_idx[1], unstable_idx[2]] + else: + # Append the spec dimension. + ans = one_d_unfolded.unsqueeze(0) + return ans diff --git a/auto_LiRPA/perturbations.py b/auto_LiRPA/perturbations.py index 2c4955f..668caae 100644 --- a/auto_LiRPA/perturbations.py +++ b/auto_LiRPA/perturbations.py @@ -1,5 +1,6 @@ import json import math +import os import numpy as np import torch from .utils import logger, eyeC @@ -127,10 +128,11 @@ def __repr__(self): return 'PerturbationLpNorm(norm=0, eps={})'.format(self.eps) -"""Perturbation constrained by the L_p norm.""" class PerturbationLpNorm(Perturbation): - def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None): + """Perturbation constrained by the L_p norm.""" + def __init__(self, eps=0, norm=np.inf, x_L=None, x_U=None, eps_min=0): self.eps = eps + self.eps_min = eps_min self.norm = norm self.dual_norm = 1 if (norm == np.inf) else (np.float64(1.0) / (1 - 1.0 / self.norm)) self.x_L = x_L @@ -149,27 +151,12 @@ def get_input_bounds(self, x, A): x_U = x + self.eps if self.x_U is None else self.x_U return x_L, x_U - # If A is an identity matrix, we will handle specially. - def concretize_matrix(self, x, A, sign, extra_constr): + def concretize_matrix(self, x, A, sign): + # If A is an identity matrix, we will handle specially. if not isinstance(A, eyeC): # A has (Batch, spec, *input_size). For intermediate neurons, spec is *neuron_size. A = A.reshape(A.shape[0], A.shape[1], -1) - if extra_constr is not None: - # For each neuron, we have a beta, so beta size is (Batch, *neuron_size, n_beta) (in A, spec is *neuron_size). - # For intermediate layer neurons, A has *neuron_size specifications. - beta = extra_constr['beta'] - beta = beta.view(beta.size(0), -1, beta.size(-1)) - # coeffs are linear relationships between split neurons and x. They have size (batch, n_beta, *input_size), and unreated to neuron_size. - beta_coeffs = extra_constr['coeffs'] - beta_coeffs = beta_coeffs.view(beta_coeffs.size(0), beta_coeffs.size(1), -1) - # biases are added for each batch each spec, size is (batch, n_beta), and unrelated to neuron_size. - beta_bias = extra_constr['bias'] - # Merge beta into extra A and bias. Extra A has size (batch, spec, *input_size). For intermediate neurons, spec is *neuron_size. - extra_A = torch.einsum('ijk,ikl->ijl', beta, beta_coeffs) - # Merge beta into the bias term. Output has size (batch, spec). - extra_bias = torch.einsum('ijk,ik->ij', beta, beta_bias) - if self.norm == np.inf: # For Linfinity distortion, when an upper and lower bound is given, we use them instead of eps. x_L, x_U = self.get_input_bounds(x, A) @@ -179,21 +166,11 @@ def concretize_matrix(self, x, A, sign, extra_constr): center = (x_ub + x_lb) / 2.0 diff = (x_ub - x_lb) / 2.0 if not isinstance(A, eyeC): - if extra_constr is not None: - # Extra linear and bias terms from constraints. - print( - f'A extra: {(sign * extra_A).abs().sum().item()}, ' - f'b extra: {(sign * extra_bias).abs().sum().item()}') - A = A - sign * extra_A - bound = A.matmul(center) - sign * extra_bias.unsqueeze(-1) + sign * A.abs().matmul(diff) - else: - bound = A.matmul(center) + sign * A.abs().matmul(diff) + bound = A.matmul(center) + sign * A.abs().matmul(diff) else: - assert extra_constr is None # A is an identity matrix. No need to do this matmul. bound = center + sign * diff else: - assert extra_constr is None x = x.reshape(x.shape[0], -1, 1) if not isinstance(A, eyeC): # Find the upper and lower bounds via dual norm. @@ -205,7 +182,7 @@ def concretize_matrix(self, x, A, sign, extra_constr): bound = bound.squeeze(-1) return bound - def concretize_patches(self, x, A, sign, extra_constr): + def concretize_patches(self, x, A, sign): if self.norm == np.inf: x_L, x_U = self.get_input_bounds(x, A) @@ -224,12 +201,7 @@ def concretize_patches(self, x, A, sign, extra_constr): bound -= bound_diff else: raise ValueError("Unsupported Sign") - - # The extra bias term from beta term. - if extra_constr is not None: - bound += extra_constr else: - assert extra_constr is None # A is an identity matrix. No need to do this matmul. bound = center + sign * diff return bound @@ -237,33 +209,39 @@ def concretize_patches(self, x, A, sign, extra_constr): input_shape = x.shape if not A.identity: # Find the upper and lower bounds via dual norm. - # matrix has shape (batch_size, out_c * out_h * out_w, input_c, input_h, input_w) or (batch_size, unstable_size, input_c, input_h, input_w) - matrix = patches_to_matrix(A.patches, input_shape, A.stride, A.padding, A.output_shape, A.unstable_idx) - # Note that we should avoid reshape the matrix. Due to padding, matrix cannot be reshaped without copying. + # matrix has shape + # (batch_size, out_c * out_h * out_w, input_c, input_h, input_w) + # or (batch_size, unstable_size, input_c, input_h, input_w) + matrix = patches_to_matrix( + A.patches, input_shape, A.stride, A.padding, A.output_shape, + A.unstable_idx) + # Note that we should avoid reshape the matrix. + # Due to padding, matrix cannot be reshaped without copying. deviation = matrix.norm(p=self.dual_norm, dim=(-3,-2,-1)) * self.eps # Bound has shape (batch, out_c * out_h * out_w) or (batch, unstable_size). bound = torch.einsum('bschw,bchw->bs', matrix, x) + sign * deviation if A.unstable_idx is None: # Reshape to (batch, out_c, out_h, out_w). - bound = bound.view(matrix.size(0), A.patches.size(0), A.patches.size(2), A.patches.size(3)) + bound = bound.view(matrix.size(0), A.patches.size(0), + A.patches.size(2), A.patches.size(3)) else: # A is an identity matrix. Its norm is all 1. bound = x + sign * self.eps return bound - """Given an variable x and its bound matrix A, compute worst case bound according to Lp norm.""" - def concretize(self, x, A, sign=-1, aux=None, extra_constr=None): + def concretize(self, x, A, sign=-1, aux=None): + """Given an variable x and its bound matrix A, compute worst case bound according to Lp norm.""" if A is None: return None if isinstance(A, eyeC) or isinstance(A, torch.Tensor): - return self.concretize_matrix(x, A, sign, extra_constr) + return self.concretize_matrix(x, A, sign) elif isinstance(A, Patches): - return self.concretize_patches(x, A, sign, extra_constr) + return self.concretize_patches(x, A, sign) else: raise NotImplementedError() - """ Sparse Linf perturbation where only a few dimensions are actually perturbed""" def init_sparse_linf(self, x, x_L, x_U): + """ Sparse Linf perturbation where only a few dimensions are actually perturbed""" self.sparse = True batch_size = x_L.shape[0] perturbed = (x_U > x_L).int() @@ -292,12 +270,19 @@ def init(self, x, aux=None, forward=False): x_L = x - self.eps if self.x_L is None else self.x_L x_U = x + self.eps if self.x_U is None else self.x_U else: - # For other norms, we pass in the BoundedTensor objects directly. - x_L = x_U = x + if int(os.environ.get('AUTOLIRPA_L2_DEBUG', 0)) == 1: + # FIXME Experimental code. Need to change the IBP code also. + x_L = x - self.eps if self.x_L is None else self.x_L + x_U = x + self.eps if self.x_U is None else self.x_U + else: + # FIXME This causes confusing lower bound and upper bound + # For other norms, we pass in the BoundedTensor objects directly. + x_L = x_U = x if not forward: return LinearBound( None, None, None, None, x_L, x_U), x, None - if self.norm == np.inf and x_L.numel() > 1 and (x_L == x_U).sum() > 0.5 * x_L.numel(): + if (self.norm == np.inf and x_L.numel() > 1 + and (x_L == x_U).sum() > 0.5 * x_L.numel()): return self.init_sparse_linf(x, x_L, x_U) batch_size = x.shape[0] @@ -311,11 +296,12 @@ def init(self, x, aux=None, forward=False): def __repr__(self): if self.norm == np.inf: if self.x_L is None and self.x_U is None: - return 'PerturbationLpNorm(norm=inf, eps={})'.format(self.eps) + return f'PerturbationLpNorm(norm=inf, eps={self.eps})' else: - return 'PerturbationLpNorm(norm=inf, eps={}, x_L={}, x_U={})'.format(self.eps, self.x_L, self.x_U) + return f'PerturbationLpNorm(norm=inf, eps={self.eps}, x_L={self.x_L}, x_U={self.x_U})' else: - return 'PerturbationLpNorm(norm={}, eps={})'.format(self.norm, self.eps) + return f'PerturbationLpNorm(norm={self.norm}, eps={self.eps})' + class PerturbationSynonym(Perturbation): def __init__(self, budget, eps=1.0, use_simple=False): @@ -328,8 +314,8 @@ def __init__(self, budget, eps=1.0, use_simple=False): self.train = False def __repr__(self): - return 'perturbation(Synonym-based word substitution budget={}, eps={})'.format( - self.budget, self.eps) + return (f'perturbation(Synonym-based word substitution ' + f'budget={self.budget}, eps={self.eps})') def _load_synonyms(self, path='data/synonyms.json'): with open(path) as file: @@ -347,7 +333,7 @@ def concretize(self, x, A, sign, aux): dim_out = A.shape[1] max_num_cand = x_rep.shape[2] - mask_rep = torch.tensor(can_be_replaced, dtype=torch.float32, device=A.device) + mask_rep = torch.tensor(can_be_replaced, dtype=torch.get_default_dtype(), device=A.device) num_pos = int(np.max(np.sum(can_be_replaced, axis=-1))) update_A = A.shape[-1] > num_pos * dim_word @@ -456,7 +442,6 @@ def init(self, x, aux=None, forward=False): eye = torch.eye(dim_word).to(x.device) lw = torch.zeros(batch_size, dim, length, dim_word).to(x.device) lb = torch.zeros_like(x).to(x.device) - x_new = [] word_embeddings = self.model.word_embeddings.weight vocab = self.model.vocab x_rep = [[[] for i in range(length)] for t in range(batch_size)] @@ -496,7 +481,7 @@ def init(self, x, aux=None, forward=False): x_rep_ += x_rep[t][i] + [zeros] * (max_num_cand - len(x_rep[t][i])) mask += [1] * len(x_rep[t][i]) + [0] * (max_num_cand - len(x_rep[t][i])) x_rep_ = torch.cat(x_rep_).reshape(batch_size, length, max_num_cand, dim_word) - mask = torch.tensor(mask, dtype=torch.float32, device=x.device)\ + mask = torch.tensor(mask, dtype=torch.get_default_dtype(), device=x.device)\ .reshape(batch_size, length, max_num_cand) x_rep_ = x_rep_ * self.eps + x.unsqueeze(2) * (1 - self.eps) @@ -509,7 +494,7 @@ def init(self, x, aux=None, forward=False): return LinearBound(lw, lb, uw, ub, lower, upper), x, (x_rep_, mask, can_be_replaced) def _build_substitution(self, batch): - for t, example in enumerate(batch): + for example in batch: if not 'candidates' in example or example['candidates'] is None: candidates = [] tokens = example['sentence'].strip().lower().split(' ') @@ -523,3 +508,4 @@ def _build_substitution(self, batch): _cand = [tokens[i]] + _cand candidates.append(_cand) example['candidates'] = candidates + diff --git a/auto_LiRPA/solver_module.py b/auto_LiRPA/solver_module.py index b8e3d22..e110705 100644 --- a/auto_LiRPA/solver_module.py +++ b/auto_LiRPA/solver_module.py @@ -1,7 +1,11 @@ from .bound_ops import * +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from .bound_general import BoundedModule -def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None, + +def build_solver_module(self: 'BoundedModule', x=None, C=None, interm_bounds=None, final_node_name=None, model_type="mip", solver_pkg="gurobi"): r"""build lp/mip solvers in general graph. @@ -11,7 +15,7 @@ def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None, C (Tensor): The specification matrix that can map the output of the model with an additional linear layer. This is usually used for maping the logits output of the model to classification margins. - intermediate_layer_bounds: if specified, will replace existing intermediate layer bounds. + interm_bounds: if specified, will replace existing intermediate layer bounds. Otherwise we reuse exising intermediate bounds. final_node_name (String): the name for the target layer to optimize @@ -21,7 +25,7 @@ def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None, Returns: output vars (list): a list of final nodes to optimize """ - # self.root_name: list of root node name + # self.root_names: list of root node name # self.final_name: list of output node name # self.final_node: output module # .input: a list of input modules of this layer module @@ -30,24 +34,23 @@ def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None, # if last layer we need to be careful with: # C: specification matrix # .is_input_perturbed(1) - if x is not None: - assert intermediate_layer_bounds is not None + assert interm_bounds is not None # Set the model to use new intermediate layer bounds, ignore the original ones. - self._set_input(x, intermediate_layer_bounds=intermediate_layer_bounds) + self.set_input(x, interm_bounds=interm_bounds) - root = [self[name] for name in self.root_name] + roots = [self[name] for name in self.root_names] # create interval ranges for input and other weight parameters - for i in range(len(root)): - value = root[i].forward() + for i in range(len(roots)): + value = roots[i].forward() # if isinstance(root[i], BoundInput) and not isinstance(root[i], BoundParams): - if type(root[i]) is BoundInput: + if type(roots[i]) is BoundInput: # create input vars for gurobi self.model - inp_gurobi_vars = self._build_solver_input(root[i]) + inp_gurobi_vars = self._build_solver_input(roots[i]) else: # regular weights - root[i].solver_vars = value + roots[i].solver_vars = value final = self.final_node() if final_node_name is None else self[final_node_name] @@ -58,10 +61,11 @@ def build_solver_module(self, x=None, C=None, intermediate_layer_bounds=None, return final.solver_vars -def _build_solver_general(self, node, C=None, model_type="mip", solver_pkg="gurobi"): +def _build_solver_general(self: 'BoundedModule', node: Bound, C=None, model_type="mip", + solver_pkg="gurobi"): if not hasattr(node, 'solver_vars'): for n in node.inputs: - self._build_solver_general(n, C=C, model_type=model_type, solver_pkg=solver_pkg) + self._build_solver_general(n, C=C, model_type=model_type) inp = [n_pre.solver_vars for n_pre in node.inputs] # print(node, node.inputs) if C is not None and isinstance(node, BoundLinear) and\ @@ -70,15 +74,20 @@ def _build_solver_general(self, node, C=None, model_type="mip", solver_pkg="guro # merge the last BoundLinear node with the specification, # available when weights of this layer are not perturbed solver_vars = node.build_solver(*inp, model=self.model, C=C, - model_type=model_type, solver_pkg=solver_pkg) + model_type=model_type, solver_pkg=solver_pkg) else: solver_vars = node.build_solver(*inp, model=self.model, C=None, model_type=model_type, solver_pkg=solver_pkg) # just return output node gurobi vars return solver_vars +def _reset_solver_vars(self: 'BoundedModule', node: Bound): + if hasattr(node, 'solver_vars'): + del node.solver_vars + for n in node.inputs: + self._reset_solver_vars(n) -def _build_solver_input(self, node): +def _build_solver_input(self: 'BoundedModule', node): ## Do the input layer, which is a special case assert isinstance(node, BoundInput) assert node.perturbation is not None @@ -86,6 +95,8 @@ def _build_solver_input(self, node): inp_gurobi_vars = [] # zero var will be shared within the solver model zero_var = self.model.addVar(lb=0, ub=0, obj=0, vtype=grb.GRB.CONTINUOUS, name='zero') + one_var = self.model.addVar(lb=1, ub=1, obj=0, vtype=grb.GRB.CONTINUOUS, name='one') + neg_one_var = self.model.addVar(lb=-1, ub=-1, obj=0, vtype=grb.GRB.CONTINUOUS, name='neg_one') x_L = node.value - node.perturbation.eps if node.perturbation.x_L is None else node.perturbation.x_L x_U = node.value + node.perturbation.eps if node.perturbation.x_U is None else node.perturbation.x_U x_L = x_L.squeeze(0) diff --git a/auto_LiRPA/utils.py b/auto_LiRPA/utils.py index d08abcb..3a83d14 100644 --- a/auto_LiRPA/utils.py +++ b/auto_LiRPA/utils.py @@ -2,12 +2,10 @@ import time import torch import torch.nn as nn -import torch.nn.functional as F import os import sys import appdirs from collections import defaultdict, namedtuple -from collections.abc import Sequence from functools import reduce import operator import warnings @@ -28,6 +26,19 @@ eyeC = namedtuple('eyeC', 'shape device') OneHotC = namedtuple('OneHotC', 'shape device index coeffs') +def onehotc_to_dense(one_hot_c: OneHotC, dtype: torch.dtype) -> torch.Tensor: + shape = one_hot_c.shape # [spec, batch, C, H, W] + dim = int(prod(shape[2:])) + dense = torch.zeros( + size=(shape[0], shape[1], dim), device=one_hot_c.device, dtype=dtype) + # one_hot_c.index has size (spec, batch), its values are the index of the one-hot non-zero elements in A. + # one_hot_c.coeffs is the value of the non-zero element. + dense = torch.scatter( + dense, dim=2, index=one_hot_c.index.unsqueeze(-1), + src=one_hot_c.coeffs.unsqueeze(-1)) + dense = dense.view(shape[0], shape[1], *shape[2:]) + return dense + # Benchmarking mode disable some expensive assertions. Benchmarking = True @@ -54,35 +65,42 @@ def reduction_str2func(reduction_func): else: return reduction_func -def stop_criterion_sum(threshold=0): - return lambda x: (x.sum(1, keepdim=True) > threshold) - -def stop_criterion_mean(threshold=0): - return lambda x: (x.mean(1, keepdim=True) > threshold) +def stop_criterion_placeholder(threshold=0): + return lambda x: RuntimeError("BUG: bound optimization stop criterion not specified.") def stop_criterion_min(threshold=0): return lambda x: (x.min(1, keepdim=True).values > threshold) +def stop_criterion_all(threshold=0): + # The dimension of x should be (batch, spec). The spec dimension + # This was used in the incomplete verifier, where the spec dimension can + # present statements in an OR clause. + return lambda x: (x > threshold).all(dim=1, keepdim=True) + def stop_criterion_max(threshold=0): return lambda x: (x.max(1, keepdim=True).values > threshold) def stop_criterion_batch(threshold=0): # may unexpected broadcast, pay attention to the shape of threshold # x shape: batch, number_bounds; threshold shape: batch, number_bounds - # print('threshold', threshold.shape) return lambda x: (x > threshold) def stop_criterion_batch_any(threshold=0): + """If any spec >= rhs, then this sample can be stopped; + if all samples can be stopped, stop = True, o.w., False. + """ # may unexpected broadcast, pay attention to the shape of threshold # x shape: batch, number_bounds; threshold shape: batch, number_bounds - # print('threshold', threshold.shape) - return lambda x: (x > threshold).any(dim=1) + return lambda x: (x > threshold).any(dim=1, keepdim=True) def stop_criterion_batch_topk(threshold=0, k=1314): # x shape: batch, number_bounds; threshold shape: batch, number_bounds - # print('threshold', threshold.shape) return lambda x: (torch.kthvalue(x, k, dim=-1, keepdim=True).values > threshold).any(dim=1) +def multi_spec_keep_func_all(x): + return torch.all(x, dim=-1) + + user_data_dir = appdirs.user_data_dir('auto_LiRPA') if not os.path.exists(user_data_dir): try: @@ -90,51 +108,48 @@ def stop_criterion_batch_topk(threshold=0, k=1314): except: logger.error('Failed to create directory {}'.format(user_data_dir)) -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count class MultiAverageMeter(object): """Computes and stores the average and current value for multiple metrics""" def __init__(self): self.reset() + def reset(self): self.sum_meter = defaultdict(float) self.lasts = defaultdict(float) self.counts_meter = defaultdict(int) - def update(self, key, val, n=1): + self.batch_size = 1 + + def set_batch_size(self, batch_size): + self.batch_size = batch_size + + def update(self, key, val, n=None): + if val is None: + return + if n is None: + n = self.batch_size if isinstance(val, torch.Tensor): val = val.item() self.lasts[key] = val self.sum_meter[key] += val * n self.counts_meter[key] += n + def last(self, key): return self.lasts[key] + def avg(self, key): if self.counts_meter[key] == 0: return 0.0 else: return self.sum_meter[key] / self.counts_meter[key] + def __repr__(self): s = "" for k in self.sum_meter: s += "{}={:.4f} ".format(k, self.avg(k)) return s.strip() + class MultiTimer(object): """Count the time for each part of training.""" def __init__(self): @@ -159,9 +174,14 @@ def __repr__(self): s += "{}_time={:.3f} ".format(k, self.timer_total[k]) return s.strip() -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) + +class Flatten(nn.Flatten): + """Legacy Flatten class. + + It was previously created when nn.Flatten was not supported. Simply use + nn.Flatten in the future.""" + pass + class Unflatten(nn.Module): def __init__(self, wh): @@ -170,6 +190,25 @@ def __init__(self, wh): def forward(self, x): return x.view(x.size(0), -1, self.wh, self.wh) + +class Max(nn.Module): + + def __init__(self): + super(Max, self).__init__() + + def forward(self, x, y): + return torch.max(x, y) + + +class Min(nn.Module): + + def __init__(self): + super(Min, self).__init__() + + def forward(self, x, y): + return torch.min(x, y) + + def scale_gradients(optimizer, gradient_accumulation_steps, grad_clip=None): parameters = [] for param_group in optimizer.param_groups: @@ -180,12 +219,6 @@ def scale_gradients(optimizer, gradient_accumulation_steps, grad_clip=None): if grad_clip is not None: return torch.nn.utils.clip_grad_norm_(parameters, grad_clip) -def recursive_map (seq, func): - for item in seq: - if isinstance(item, Sequence): - yield type(item)(recursive_map(item, func)) - else: - yield func(item) # unpack tuple, dict, list into one single list # TODO: not sure if the order matches graph.inputs() @@ -202,14 +235,17 @@ def unpack_inputs(inputs, device=None): inputs = inputs.to(device) return [inputs] + def isnan(x): if isinstance(x, Patches): return False return torch.isnan(x).any() + def prod(x): return reduce(operator.mul, x, 1) + def batched_index_select(input, dim, index): # Assuming the input has a batch dimension. # index has dimensin [spec, batch]. @@ -231,16 +267,6 @@ def batched_index_select(input, dim, index): return torch.gather(input, dim, index) -def check_padding(x, padding): - if isinstance(padding, int): - return x, (padding, padding) - if len(padding) == 2: - return x, padding - if (padding[0] == padding[1]) and (padding[2] == padding[3]): - return x, (padding[0], padding[2]) - return F.pad(x, padding), (0, 0) - - def get_spec_matrix(X, y, num_classes): with torch.no_grad(): c = (torch.eye(num_classes).type_as(X)[y].unsqueeze(1) @@ -249,6 +275,7 @@ def get_spec_matrix(X, y, num_classes): c = (c[I].view(X.size(0), num_classes - 1, num_classes)) return c + def unravel_index( indices: torch.LongTensor, shape: Tuple[int, ...], @@ -274,14 +301,19 @@ def unravel_index( return list(reversed(coord)) -def get_A_shape(A): - if A is None: - return 'None' - if isinstance(A, Patches): - if A.patches is not None: - return A.patches.shape - else: - return A.shape - if isinstance(A, torch.Tensor): - return A.shape - return 'Unknown' + +def fill_template(out, template): + if template is None: + return out.popleft() + elif isinstance(template, (list, tuple)): + res = [] + for t in template: + res.append(fill_template(t)) + return tuple(res) if isinstance(template, tuple) else res + elif isinstance(template, dict): + res = {} + for key in template: + res[key] = fill_template(template[key]) + return res + else: + raise NotImplementedError diff --git a/doc/api.rst b/doc/api.rst index 5375cb5..b634e91 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -5,6 +5,7 @@ API Usage .. autofunction:: auto_LiRPA.BoundedModule.forward .. autofunction:: auto_LiRPA.BoundedModule.compute_bounds + .. autofunction:: auto_LiRPA.BoundedModule.save_intermediate .. autoclass:: auto_LiRPA.bound_ops.Bound diff --git a/doc/src/jacobian.md b/doc/src/jacobian.md new file mode 100644 index 0000000..601f74c --- /dev/null +++ b/doc/src/jacobian.md @@ -0,0 +1,54 @@ +# APIs for Jacobian + +## Specifying a Jacobian computation in the model + +When defining a computational graph by creating a `torch.nn.Module`, a Jacobian computation can be introduced with `JacobianOP.apply(y, x)` which denotes computing the Jacobian between `y` and `x`. + +For example, given a regular `model`, we may wrap it into a `JacobianWrapper` for computing the Jacobian between the output and the input of the model: +```python +import torch.nn as nn +from auto_LiRPA.bound_ops import JacobianOP + +class JacobianWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + y = self.model(x) + return JacobianOP.apply(y, x) +``` +Note that `JacobianOP.apply` only returns dummy values if we directly run this PyTorch model. +The actual Jacobian computation will be parsed when the model is wrapped into a `BoundedModule`. +See more [examples](../../examples/jacobian_new.py) including computing local Lipschitz constants and Jacobian-Vector products using `JacobianOP`. + +## Adding new operators + +To support the Jacobian bounds for a new operator, we need to ensure that there are bound operators implemented for the forward computation (the computation of the operator itself) and the backward computation (the computation of gradient) respectively. +Builtin operators are implemented in [auto_LiRPA/operators](../../auto_LiRPA/operators). +For example, for ReLU, we have [`BoundRelu`](../../auto_LiRPA/operators/relu.py) for the forward computation and [`BoundReluGrad`](../../auto_LiRPA/operators/gradient_bounds.py) for the backward computation. +Follow the [document](custom_op.md) to add new custom operators if necessary. + +Then for the forward operator, implement a `build_gradient_node` function. +This function tells the library how a gradient module should be created given the forward operator when building the computational graph with the Jacobian computation. +The function takes a single argument `grad_upstream` which is the upstream gradient during the gradient back-propagation. +The function should return three variables in a tuple, including `module_grad`, `grad_input` and `grad_extra_nodes`. +`module_grad` is a `torch.nn.Module` and the created module for the gradient computation. +`grad_input` contains a list of tensors denoting the gradients propagated to the input nodes. +`grad_extra_nodes` may contain a list of extra nodes if needed for gradient computation. +Note that for `grad_upstream` and `grad_input`, we only care about the shapes of the gradient tensors, and their values do not matter and can be dummy values. +See examples in [relu.py](../../auto_LiRPA/operators/relu.py) or [linear.py](../../auto_LiRPA/operators/linear.py). + +## References + +Please cite our paper for the Jacobian computation: +``` +@article{shi2022efficiently, + title={Efficiently computing local lipschitz constants of neural networks via bound propagation}, + author={Shi, Zhouxing and Wang, Yihan and Zhang, Huan and Kolter, J Zico and Hsieh, Cho-Jui}, + journal={Advances in Neural Information Processing Systems}, + volume={35}, + pages={2350--2364}, + year={2022} +} +``` diff --git a/examples/language/Transformer/Transformer.py b/examples/language/Transformer/Transformer.py index ddb9309..4493aee 100644 --- a/examples/language/Transformer/Transformer.py +++ b/examples/language/Transformer/Transformer.py @@ -16,42 +16,25 @@ from __future__ import absolute_import, division, print_function -import argparse -import csv import os -import random -import sys -import shutil -import scipy -import pickle -import pdb -import numpy as np import torch import torch.nn as nn -from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, - TensorDataset) -from torch.utils.data.distributed import DistributedSampler -from torch.nn import CrossEntropyLoss, MSELoss -from scipy.stats import pearsonr, spearmanr -from sklearn.metrics import matthews_corrcoef, f1_score - -from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.tokenization import BertTokenizer -from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule - -from Transformer.modeling import BertForSequenceClassification, BertConfig + +from Transformer.modeling import BertForSequenceClassification +from pytorch_pretrained_bert.modeling import BertConfig from Transformer.utils import convert_examples_to_features from language_utils import build_vocab from auto_LiRPA.utils import logger + class Transformer(nn.Module): def __init__(self, args, data_train): super().__init__() self.args = args self.max_seq_length = args.max_sent_length - self.drop_unk = args.drop_unk + self.drop_unk = args.drop_unk self.num_labels = args.num_classes - self.label_list = range(args.num_classes) + self.label_list = range(args.num_classes) self.device = args.device self.lr = args.lr @@ -86,19 +69,19 @@ def __init__(self, args, data_train): def save(self, epoch): self.model.model_from_embeddings = self.model_from_embeddings path = os.path.join(self.dir, "ckpt_{}".format(epoch)) - torch.save({ - 'state_dict_embeddings': self.model.embeddings.state_dict(), - 'state_dict_model_from_embeddings': self.model.model_from_embeddings.state_dict(), + torch.save({ + 'state_dict_embeddings': self.model.embeddings.state_dict(), + 'state_dict_model_from_embeddings': self.model.model_from_embeddings.state_dict(), 'epoch': epoch }, path) logger.info("Model saved to {}".format(path)) - + def build_optimizer(self): # update the original model with the converted model self.model.model_from_embeddings = self.model_from_embeddings param_group = [ {"params": [p[1] for p in self.model.named_parameters()], "weight_decay": 0.}, - ] + ] return torch.optim.Adam(param_group, lr=self.lr) def train(self): @@ -106,7 +89,7 @@ def train(self): self.model_from_embeddings.train() def eval(self): - self.model.eval() + self.model.eval() self.model_from_embeddings.eval() def get_input(self, batch): @@ -115,7 +98,7 @@ def get_input(self, batch): input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long).to(self.device) input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long).to(self.device) - segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long).to(self.device) + segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long).to(self.device) label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long).to(self.device) tokens = [f.tokens for f in features] @@ -126,6 +109,6 @@ def get_input(self, batch): def forward(self, batch): embeddings, extended_attention_mask, tokens, label_ids = self.get_input(batch) - logits = self.model_from_embeddings(embeddings, extended_attention_mask) + logits = self.model_from_embeddings(embeddings, extended_attention_mask) preds = torch.argmax(logits, dim=1) return preds \ No newline at end of file diff --git a/examples/language/Transformer/modeling.py b/examples/language/Transformer/modeling.py index 7e74ecd..0c2a837 100644 --- a/examples/language/Transformer/modeling.py +++ b/examples/language/Transformer/modeling.py @@ -17,24 +17,10 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import copy -import json -import math -import os -import shutil -import tarfile -import tempfile -import sys -from io import open - import torch from torch import nn -from torch.nn import CrossEntropyLoss - -from pytorch_pretrained_bert.file_utils import cached_path, WEIGHTS_NAME, CONFIG_NAME -from pytorch_pretrained_bert.modeling import ACT2FN, BertConfig, BertIntermediate, \ - BertSelfAttention, BertPreTrainedModel +from pytorch_pretrained_bert.modeling import BertIntermediate, BertSelfAttention, BertPreTrainedModel class BertLayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): @@ -59,7 +45,7 @@ def __init__(self, hidden_size, eps=1e-12): def forward(self, x): u = x.mean(-1, keepdim=True) x = x - u - return self.weight * x + self.bias + return self.weight * x + self.bias class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. @@ -71,7 +57,7 @@ def __init__(self, config, glove=None, vocab=None): self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size) self.config = config - + def forward(self, input_ids, token_type_ids=None): seq_length = input_ids.size(1) position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) @@ -85,7 +71,7 @@ def forward(self, input_ids, token_type_ids=None): # position/token_type embedding disabled # embeddings = words_embeddings + position_embeddings + token_type_embeddings - + embeddings = words_embeddings return embeddings @@ -95,7 +81,7 @@ def __init__(self, config): self.config = config self.dense = nn.Linear(config.hidden_size, config.hidden_size) if hasattr(config, "layer_norm") and config.layer_norm == "no_var": - self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12) + self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12) else: self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -104,7 +90,7 @@ def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) if hidden_states.shape[-1] == input_tensor.shape[-1]: - hidden_states = hidden_states + input_tensor + hidden_states = hidden_states + input_tensor if hasattr(self.config, "layer_norm") and self.config.layer_norm == "no": pass else: @@ -129,7 +115,7 @@ def __init__(self, config): self.config = config self.dense = nn.Linear(config.intermediate_size, config.hidden_size) if hasattr(config, "layer_norm") and config.layer_norm == "no_var": - self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12) + self.LayerNorm = BertLayerNormNoVar(config.hidden_size, eps=1e-12) else: self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) @@ -196,7 +182,7 @@ def __init__(self, config): self.apply(self.init_bert_weights) def forward(self, embeddings, extended_attention_mask): - encoded_layers = self.encoder(embeddings, extended_attention_mask) + encoded_layers = self.encoder(embeddings, extended_attention_mask) sequence_output = encoded_layers[-1] pooled_output = self.pooler(sequence_output) return pooled_output @@ -212,7 +198,7 @@ def __init__(self, config, num_labels=2): self.layer_norm = config.layer_norm if hasattr(config, "layer_norm") and config.layer_norm == "no_var": - self.LayerNorm = BertLayerNormNoVar(config.embedding_size, eps=1e-12) + self.LayerNorm = BertLayerNormNoVar(config.embedding_size, eps=1e-12) else: self.LayerNorm = BertLayerNorm(config.embedding_size, eps=1e-12) @@ -226,7 +212,7 @@ def forward(self, embeddings, extended_attention_mask): else: embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) - + pooled_output = self.bert(embeddings, extended_attention_mask) pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) diff --git a/examples/language/Transformer/utils.py b/examples/language/Transformer/utils.py index 64a2aed..5ebcc75 100644 --- a/examples/language/Transformer/utils.py +++ b/examples/language/Transformer/utils.py @@ -13,21 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import torch -from sklearn.metrics import matthews_corrcoef, f1_score -from language_utils import tokenize, token_to_id - -def simple_accuracy(preds, labels): - return (preds == labels).mean() -def acc_and_f1(preds, labels): - acc = simple_accuracy(preds, labels) - f1 = f1_score(y_true=labels, y_pred=preds) - return { - "acc": acc, - "f1": f1, - "acc_and_f1": (acc + f1) / 2, - } +from language_utils import tokenize, token_to_id class InputExample(object): def __init__(self, guid, text_a, text_b=None, label=None): @@ -48,8 +35,6 @@ def convert_examples_to_features(examples, label_list, max_seq_length, vocab, drop_unk=False): #tokenizer): """Loads a data file into a list of `InputBatch`s.""" - label_map = {label : i for i, label in enumerate(label_list)} - features = [] all_tokens = tokenize(examples, vocab, max_seq_length - 2, drop_unk=drop_unk) for i in range(len(all_tokens)): @@ -77,5 +62,5 @@ def convert_examples_to_features(examples, label_list, max_seq_length, segment_ids=segment_ids, label_id=example["label"], tokens=tokens)) - - return features \ No newline at end of file + + return features diff --git a/examples/language/lstm.py b/examples/language/lstm.py index 0973ae2..66e242b 100644 --- a/examples/language/lstm.py +++ b/examples/language/lstm.py @@ -60,13 +60,13 @@ def __init__(self, args, data_train): self.checkpoint = 0 if args.load: - ckpt = torch.load(args.load, map_location=torch.device(self.device)) + ckpt = torch.load(args.load, map_location=torch.device(self.device)) self.embedding = torch.nn.Embedding(len(self.vocab), self.embedding_size) self.model_from_embeddings = LSTMFromEmbeddings(args, len(self.vocab)) self.model = self.embedding, LSTMFromEmbeddings(args, len(self.vocab)) self.embedding.load_state_dict(ckpt['state_dict_embedding']) self.model_from_embeddings.load_state_dict(ckpt['state_dict_model_from_embeddings']) - self.checkpoint = ckpt['epoch'] + self.checkpoint = ckpt['epoch'] else: self.embedding = torch.nn.Embedding(len(self.vocab), self.embedding_size) self.model_from_embeddings = LSTMFromEmbeddings(args, len(self.vocab)) @@ -79,7 +79,7 @@ def __init__(self, args, data_train): def save(self, epoch): path = os.path.join(self.dir, 'ckpt_{}'.format(epoch)) torch.save({ - 'state_dict_embedding': self.embedding.state_dict(), + 'state_dict_embedding': self.embedding.state_dict(), 'state_dict_model_from_embeddings': self.model_from_embeddings.state_dict(), 'epoch': epoch }, path) @@ -91,7 +91,7 @@ def build_optimizer(self): for m in self.model: for p in m.named_parameters(): param_group.append(p) - param_group = [{"params": [p[1] for p in param_group], "weight_decay": 0.}] + param_group = [{"params": [p[1] for p in param_group], "weight_decay": 0.}] return torch.optim.Adam(param_group, lr=self.lr) def get_input(self, batch): @@ -123,4 +123,4 @@ def train(self): self.model_from_embeddings.train() def eval(self): - self.model_from_embeddings.eval() \ No newline at end of file + self.model_from_embeddings.eval() diff --git a/examples/language/train.py b/examples/language/train.py index 8c6f5b2..d8cd168 100644 --- a/examples/language/train.py +++ b/examples/language/train.py @@ -12,7 +12,7 @@ from torch.nn import CrossEntropyLoss from torch.utils.tensorboard import SummaryWriter from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationSynonym, CrossEntropyWrapperMultiInput -from auto_LiRPA.utils import MultiAverageMeter, AverageMeter, logger, scale_gradients +from auto_LiRPA.utils import MultiAverageMeter, logger, scale_gradients from auto_LiRPA.eps_scheduler import * from Transformer.Transformer import Transformer from lstm import LSTM @@ -40,8 +40,8 @@ parser.add_argument('--model', type=str, default='transformer', choices=['transformer', 'lstm']) -parser.add_argument('--num_epochs', type=int, default=25) -parser.add_argument('--num_epochs_all_nodes', type=int, default=20) +parser.add_argument('--num_epochs', type=int, default=25) +parser.add_argument('--num_epochs_all_nodes', type=int, default=20) parser.add_argument('--eps_start', type=int, default=1) parser.add_argument('--eps_length', type=int, default=10) parser.add_argument('--log_interval', type=int, default=100) @@ -54,11 +54,11 @@ parser.add_argument('--lr', type=float, default=1e-4) parser.add_argument('--lr_decay', type=float, default=1) parser.add_argument('--grad_clip', type=float, default=10.0) -parser.add_argument('--num_classes', type=int, default=2) +parser.add_argument('--num_classes', type=int, default=2) parser.add_argument('--num_layers', type=int, default=1) parser.add_argument('--num_attention_heads', type=int, default=4) parser.add_argument('--hidden_size', type=int, default=64) -parser.add_argument('--embedding_size', type=int, default=64) +parser.add_argument('--embedding_size', type=int, default=64) parser.add_argument('--intermediate_size', type=int, default=128) parser.add_argument('--drop_unk', action='store_true') parser.add_argument('--hidden_act', type=str, default='relu') @@ -68,7 +68,7 @@ parser.add_argument('--dropout', type=float, default=0.1) parser.add_argument('--bound_opts_relu', type=str, default='zero-lb') -args = parser.parse_args() +args = parser.parse_args() writer = SummaryWriter(os.path.join(args.dir, 'log'), flush_secs=10) file_handler = logging.FileHandler(os.path.join(args.dir, 'log/train.log')) @@ -83,6 +83,11 @@ random.shuffle(data_test) data_test = data_test[:10] assert args.batch_size >= 10 + # Use double precision and deterministic algorithm for automatic testing. + os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' + torch.use_deterministic_algorithms(True) + torch.set_default_dtype(torch.float64) + logger.info('Dataset sizes: {}/{}/{}/{}'.format( len(data_train_all_nodes), len(data_train), len(data_dev), len(data_test))) @@ -100,7 +105,7 @@ elif args.model == 'lstm': dummy_mask = torch.zeros(1, args.max_sent_length, device=args.device) model = LSTM(args, data_train) - + dev_batches = get_batches(data_dev, args.batch_size) test_batches = get_batches(data_test, args.batch_size) @@ -117,8 +122,8 @@ if args.loss_fusion: bound_opts['loss_fusion'] = True model_loss = BoundedModule( - CrossEntropyWrapperMultiInput(model_ori), - (torch.zeros(1, dtype=torch.long), dummy_embeddings, dummy_mask), + CrossEntropyWrapperMultiInput(model_ori), + (torch.zeros(1, dtype=torch.long), dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device) ptb.model = model @@ -166,7 +171,7 @@ def step(model, ptb, batch, eps=1.0, train=False): # loss_fusion loss if args.method == 'IBP+backward_train': lb, ub = model_loss.compute_bounds( - x=(labels, embeddings, mask), aux=aux, + x=(labels, embeddings, mask), aux=aux, C=None, method='IBP+backward', bound_lower=False) else: raise NotImplementedError @@ -202,11 +207,11 @@ def step(model, ptb, batch, eps=1.0, train=False): acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float()) else: acc_robust, loss_robust = acc, loss - + if train or args.auto_test: loss_robust.backward() grad_embed = torch.autograd.grad( - embeddings_unbounded, model.word_embeddings.weight, + embeddings_unbounded, model.word_embeddings.weight, grad_outputs=embeddings.grad)[0] if model.word_embeddings.weight.grad is None: model.word_embeddings.weight.grad = grad_embed @@ -214,19 +219,23 @@ def step(model, ptb, batch, eps=1.0, train=False): model.word_embeddings.weight.grad += grad_embed if args.auto_test: + print('Saving results for automated tests.') + print(f'acc={acc}, loss={loss}, robust_acc={acc_robust}, robust_loss={loss_robust}') + print('gradients:') + print(grad_embed) with open('res_test.pkl', 'wb') as file: pickle.dump(( - float(acc), float(loss), float(acc_robust), float(loss_robust), + float(acc), float(loss), float(acc_robust), float(loss_robust), grad_embed.detach().numpy()), file) return acc, loss, acc_robust, loss_robust def train(epoch, batches, type): meter = MultiAverageMeter() - assert(optimizer is not None) + assert(optimizer is not None) train = type == 'train' if args.robust: - eps_scheduler.set_epoch_length(len(batches)) + eps_scheduler.set_epoch_length(len(batches)) if train: eps_scheduler.train() eps_scheduler.step_epoch() @@ -248,9 +257,9 @@ def train(epoch, batches, type): if (i + 1) % args.gradient_accumulation_steps == 0 or (i + 1) == len(batches): scale_gradients(optimizer, i % args.gradient_accumulation_steps + 1, args.grad_clip) optimizer.step() - optimizer.zero_grad() + optimizer.zero_grad() if lr_scheduler is not None: - lr_scheduler.step() + lr_scheduler.step() writer.add_scalar('loss_train_{}'.format(epoch), meter.avg('loss'), i + 1) writer.add_scalar('loss_robust_train_{}'.format(epoch), meter.avg('loss_rob'), i + 1) writer.add_scalar('acc_train_{}'.format(epoch), meter.avg('acc'), i + 1) @@ -267,7 +276,7 @@ def train(epoch, batches, type): if train: if args.loss_fusion: - state_dict_loss = model_loss.state_dict() + state_dict_loss = model_loss.state_dict() state_dict = {} for name in state_dict_loss: assert(name.startswith('model.')) @@ -275,7 +284,7 @@ def train(epoch, batches, type): model_ori.load_state_dict(state_dict) model_bound = BoundedModule( model_ori, (dummy_embeddings, dummy_mask), bound_opts=bound_opts, device=args.device) - model.model_from_embeddings = model_bound + model.model_from_embeddings = model_bound model.save(epoch) return meter.avg('acc_rob') @@ -303,7 +312,7 @@ def main(): res.append(acc_rob) logger.info('Verification results:') for i in range(len(res)): - logger.info('budget {} acc_rob {:.3f}'.format(i + 1, res[i])) + logger.info('budget {} acc_rob {:.3f}'.format(i + 1, res[i])) logger.info(res) else: train(None, test_batches, 'test') diff --git a/examples/requirements.txt b/examples/requirements.txt index 2aeee45..aa4c0f9 100644 --- a/examples/requirements.txt +++ b/examples/requirements.txt @@ -1,12 +1,5 @@ -thop>=0.0.31.post2004101309 -tensorboard>=1.14 -scikit_learn>=0.21 -torchvision>=0.9.1,<0.13 -torch>=1.8 -scipy>=1.3 +thop>=0.1.1.post2209072238 +tensorboard>=2.12 pytorch_pretrained_bert>=0.6 query>=0.1 -tqdm>=4.43 -matplotlib>=3.2 -sortedcontainers>=2.4 -psutil>=5.8 +matplotlib>=3.7.1 \ No newline at end of file diff --git a/examples/sequence/train.py b/examples/sequence/train.py index c778ebc..5a9f510 100644 --- a/examples/sequence/train.py +++ b/examples/sequence/train.py @@ -7,7 +7,7 @@ from lstm import LSTM from data_utils import load_data, get_batches from auto_LiRPA import BoundedModule, BoundedTensor, PerturbationLpNorm -from auto_LiRPA.utils import AverageMeter, logger, get_spec_matrix +from auto_LiRPA.utils import MultiAverageMeter, logger, get_spec_matrix parser = argparse.ArgumentParser() parser.add_argument("--seed", type=int, default=0) @@ -15,17 +15,17 @@ parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"]) parser.add_argument("--norm", type=int, default=np.inf) parser.add_argument("--eps", type=float, default=0.1) -parser.add_argument("--num_epochs", type=int, default=20) +parser.add_argument("--num_epochs", type=int, default=20) parser.add_argument("--batch_size", type=int, default=512) parser.add_argument("--num_slices", type=int, default=8) parser.add_argument("--hidden_size", type=int, default=256) -parser.add_argument("--num_classes", type=int, default=10) +parser.add_argument("--num_classes", type=int, default=10) parser.add_argument("--input_size", type=int, default=784) parser.add_argument("--lr", type=float, default=1e-2) parser.add_argument("--dir", type=str, default="model", help="directory to load or save the model") parser.add_argument("--num_epochs_warmup", type=int, default=10, help="number of epochs for the warmup stage when eps is linearly increased from 0 to the full value") parser.add_argument("--log_interval", type=int, default=10, help="interval of printing the log during training") -args = parser.parse_args() +args = parser.parse_args() ## Train or test one batch. @@ -52,7 +52,7 @@ def step(model, ptb, batch, eps=args.eps, train=False): # Report accuracy and robust accuracy. acc = (torch.argmax(logits, dim=-1) == y).float().mean() - acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float()) + acc_robust = 1 - torch.mean((lb < 0).any(dim=1).float()) if train: loss.backward() @@ -62,10 +62,10 @@ def step(model, ptb, batch, eps=args.eps, train=False): ## Train one epoch. def train(epoch): + meter = MultiAverageMeter() model.train() # Load data for a epoch. train_batches = get_batches(data_train, args.batch_size) - for a in avg: a.reset() eps_inc_per_step = 1.0 / (args.num_epochs_warmup * len(train_batches)) @@ -73,36 +73,38 @@ def train(epoch): # We increase eps linearly every batch. eps = args.eps * min(eps_inc_per_step * ((epoch - 1) * len(train_batches) + i + 1), 1.0) # Call the main training loop. - acc, acc_robust, loss = res = step(model, ptb, batch, eps=eps, train=True) + acc, acc_robust, loss = step(model, ptb, batch, eps=eps, train=True) # Optimize the loss. torch.nn.utils.clip_grad_norm_(model.core.parameters(), 5.0) optimizer.step() - optimizer.zero_grad() - # Print training statistics. - for k in range(3): - avg[k].update(res[k], len(batch)) + optimizer.zero_grad() + meter.set_batch_size(len(batch)) + meter.update('acc', acc) + meter.update('acc_rob', acc_robust) + meter.update('loss', loss) if (i + 1) % args.log_interval == 0: - logger.info("Epoch {}, training step {}/{}: acc {:.3f}, robust acc {:.3f}, loss {:.3f}, eps {:.3f}".format( - epoch, i + 1, len(train_batches), avg_acc.avg, avg_acc_robust.avg, avg_loss.avg, eps)) + logger.info("Epoch %d, training step %d/%d: %s, eps {:.3f}".format( + epoch, i + 1, len(train_batches), meter, eps)) model.save(epoch) ## Test accuracy and robust accuracy. def test(epoch, batches): + meter = MultiAverageMeter() model.eval() - for a in avg: a.reset() - for i, batch in enumerate(batches): - acc, acc_robust, loss = res = step(model, ptb, batch) - for k in range(3): - avg[k].update(res[k], len(batch)) - logger.info("Epoch {} test: acc {:.3f}, robust acc {:.3f}, loss {:.5f}".format( - epoch, avg_acc.avg, avg_acc_robust.avg, avg_loss.avg)) + for batch in batches: + acc, acc_robust, loss = step(model, ptb, batch) + meter.set_batch_size(len(batch)) + meter.update('acc', acc) + meter.udpate('acc_rob', acc_robust) + meter.update('loss', loss) + logger.info("Epoch %d test: {%s}".format(epoch, meter)) # Load MNIST dataset logger.info("Loading data...") data_train, data_test = load_data() logger.info("Dataset sizes: {}/{}".format(len(data_train), len(data_test))) -test_batches = get_batches(data_test, args.batch_size) +test_batches = get_batches(data_test, args.batch_size) # Set all random seeds. random.seed(args.seed) @@ -112,18 +114,15 @@ def test(epoch, batches): # Create a LSTM sequence classifier. logger.info("Creating LSTM model...") -model = LSTM(args).to(args.device) +model = LSTM(args).to(args.device) X, y = model.get_input(test_batches[0]) # Create the perturbation object once here, and we can reuse it. -ptb = PerturbationLpNorm(norm=args.norm, eps=args.eps) +ptb = PerturbationLpNorm(norm=args.norm, eps=args.eps) # Convert the LSTM to BoundedModule X = BoundedTensor(X, ptb) model.core = BoundedModule(model.core, (X,), device=args.device) optimizer = model.build_optimizer() -# Averaging accuracym robust accuracy and loss. -avg_acc, avg_acc_robust, avg_loss = avg = [AverageMeter() for i in range(3)] - # Main training loop. for t in range(model.checkpoint, args.num_epochs): train(t + 1) diff --git a/examples/vision/custom_op.py b/examples/vision/custom_op.py index 5b70fa0..ef0db08 100644 --- a/examples/vision/custom_op.py +++ b/examples/vision/custom_op.py @@ -51,7 +51,7 @@ def __init__(self, attr, inputs, output_index, options): def forward(self, x): return x + self.const - def bound_backward(self, last_lA, last_uA, x): + def bound_backward(self, last_lA, last_uA, x, *args, **kwargs): """ Backward mode bound propagation """ print('Calling bound_backward for custom::PlusConstant') def _bound_oneside(last_A): diff --git a/examples/vision/jacobian.py b/examples/vision/jacobian.py index 50b2bdf..f98b699 100644 --- a/examples/vision/jacobian.py +++ b/examples/vision/jacobian.py @@ -35,91 +35,108 @@ def build_model(in_ch=3, in_dim=32, width=32, linear_size=256): return model -torch.manual_seed(0) - -# Create a small model and load pre-trained parameters. -model_ori = build_model(width=4, linear_size=32) -model_ori.load_state_dict(torch.load('pretrained/cifar_2c2f.pth')) -device = 'cuda' if torch.cuda.is_available() else 'cpu' -model_ori = model_ori.to(device) -print('Model:', model_ori) - -# Prepare the dataset -test_data = datasets.CIFAR10('./data', train=False, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize( - mean=[0.4914, 0.4822, 0.4465], std=[0.2009, 0.2009, 0.2009])])) -x0 = test_data[0][0].unsqueeze(0).to(device) - - -# Example 1: Convert the model for Jacobian bound computation -model = BoundedModule(model_ori, x0, device=device) -model.augment_gradient_graph(x0) - -# Sanity check to ensure that the new graph matches the original gradient computation -y = model_ori(x0.requires_grad_(True)) -ret_ori = torch.autograd.grad(y.sum(), x0)[0].view(1, -1) -# After running augment_gradient_graph, the model takes an additional input -# (the second input) which is a linear mapping applied on the output of the -# model before computing the gradient. It is the same as "grad_outputs" in -# torch.autograd.grad, which is "the 'vector' in the vector-Jacobian product". -# Here, setting torch.ones(1, 10) is equivalent to computing the gradients for -# y.sum() above. -ret_new = model(x0, torch.ones(1, 10).to(x0)) -assert torch.allclose(ret_ori, ret_new) - -for eps in [0, 1./255, 4./255]: - # The input region considered is an Linf ball with radius eps around x0. - x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) - # Compute the Linf locaal Lipscphitz constant - lower, upper = model.compute_jacobian_bounds(x) - print(f'Gap between upper and lower Jacobian bound for eps={eps:.5f}', - (upper - lower).max()) - if eps == 0: - assert torch.allclose(ret_new, lower.sum(dim=0, keepdim=True)) - assert torch.allclose(ret_new, upper.sum(dim=0, keepdim=True)) - - -# Example 2: Convert the model for Linf local Lipschitz constant computation -model = BoundedModule(model_ori, x0, device=device) -# Set norm=np.inf for Linf local Lipschitz constant -model.augment_gradient_graph(x0, norm=np.inf) - -# Sanity check to ensure that the new graph matches the original gradient computation -y = model_ori(x0.requires_grad_(True)) -ret_ori = torch.autograd.grad(y.sum(), x0)[0].abs().sum().view(-1) -ret_new = model(x0, torch.ones(1, 10).to(x0)).view(-1) -assert torch.allclose(ret_ori, ret_new) - -for eps in [0, 1./255, 4./255]: - # The input region considered is an Linf ball with radius eps around x0. - x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) - # Compute the Linf locaal Lipschitz constant - result = model.compute_jacobian_bounds(x) - print(f'Linf local Lipschitz constant for eps={eps:.5f}', result) - - -# Example 3: Convert the model for Jacobian-Vector Product (JVP) computation -model = BoundedModule(model_ori, x0, device=device) -vector = torch.randn(x0.shape).to(x0) -# Set vector for JVP computation -model.augment_gradient_graph(x0, vector=vector) - -# Sanity check to ensure that the new graph matches the original JVP -def func(x0): - return model_ori(x0.requires_grad_(True)) -ret_ori = torch.autograd.functional.jvp(func, x0, vector)[-1].view(-1) -ret_new = torch.zeros(10).to(x0) -for i in range(10): - c = F.one_hot(torch.tensor([i], dtype=torch.long), 10).to(x0) - ret_new[i] = model(x0, c) -assert torch.allclose(ret_ori, ret_new) - -for eps in [0, 1./255, 4./255]: - # The input region considered is an Linf ball with radius eps around x0. - x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) - # Compute the JVP - lower, upper = model.compute_jacobian_bounds(x) - print(f'JVP lower bound for eps={eps:.5f}', lower.view(-1)) - print(f'JVP upper bound for eps={eps:.5f}', upper.view(-1)) +def compute_jacobians(model_ori, x0, bound_opts=None, device='cpu'): + """Compute Jacobians given a model and an input.""" + + results = [[] for _ in range(3)] + + model_ori = model_ori.to(device) + x0 = x0.to(device) + print('Model:', model_ori) + + # Example 1: Convert the model for Jacobian bound computation + model = BoundedModule(model_ori, x0, bound_opts=bound_opts, device=device) + model.augment_gradient_graph(x0) + + # Sanity check to ensure that the new graph matches the original gradient computation + y = model_ori(x0.requires_grad_(True)) + ret_ori = torch.autograd.grad(y.sum(), x0)[0].view(1, -1) + # After running augment_gradient_graph, the model takes an additional input + # (the second input) which is a linear mapping applied on the output of the + # model before computing the gradient. It is the same as "grad_outputs" in + # torch.autograd.grad, which is "the 'vector' in the vector-Jacobian product". + # Here, setting torch.ones(1, 10) is equivalent to computing the gradients for + # y.sum() above. + ret_new = model(x0, torch.ones(1, 10).to(x0)) + assert torch.allclose(ret_ori, ret_new) + + for eps in [0, 1./255, 4./255]: + # The input region considered is an Linf ball with radius eps around x0. + x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) + # Compute the Linf locaal Lipscphitz constant + lower, upper = model.compute_jacobian_bounds(x) + print(f'Gap between upper and lower Jacobian bound for eps={eps:.5f}', + (upper - lower).max()) + if eps == 0: + assert torch.allclose(ret_new, lower.sum(dim=0, keepdim=True)) + assert torch.allclose(ret_new, upper.sum(dim=0, keepdim=True)) + results[0].append((lower.detach(), upper.detach())) + + # Example 2: Convert the model for Linf local Lipschitz constant computation + model = BoundedModule(model_ori, x0, bound_opts=bound_opts, device=device) + # Set norm=np.inf for Linf local Lipschitz constant + model.augment_gradient_graph(x0, norm=np.inf) + + # Sanity check to ensure that the new graph matches the original gradient computation + y = model_ori(x0.requires_grad_(True)) + ret_ori = torch.autograd.grad(y.sum(), x0)[0].abs().sum().view(-1) + ret_new = model(x0, torch.ones(1, 10).to(x0)).view(-1) + assert torch.allclose(ret_ori, ret_new) + + for eps in [0, 1./255, 4./255]: + # The input region considered is an Linf ball with radius eps around x0. + x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) + # Compute the Linf locaal Lipschitz constant + result = model.compute_jacobian_bounds(x) + print(f'Linf local Lipschitz constant for eps={eps:.5f}', result) + results[1].append(result.detach()) + + # Example 3: Convert the model for Jacobian-Vector Product (JVP) computation + model = BoundedModule(model_ori, x0, bound_opts=bound_opts, device=device) + vector = torch.rand_like(x0) + # Set vector for JVP computation + model.augment_gradient_graph(x0, vector=vector) + + # Sanity check to ensure that the new graph matches the original JVP + def func(x0): + return model_ori(x0.requires_grad_(True)) + ret_ori = torch.autograd.functional.jvp(func, x0, vector)[-1].view(-1) + ret_new = torch.zeros(10).to(x0) + for i in range(10): + c = F.one_hot(torch.tensor([i], dtype=torch.long), 10).to(x0) + ret_new[i] = model(x0, c) + assert torch.allclose(ret_ori, ret_new) + + for eps in [0, 1./255, 4./255]: + # The input region considered is an Linf ball with radius eps around x0. + x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) + # Compute the JVP + lower, upper = model.compute_jacobian_bounds(x) + print(f'JVP lower bound for eps={eps:.5f}', lower.view(-1)) + print(f'JVP upper bound for eps={eps:.5f}', upper.view(-1)) + results[2].append((lower.detach(), upper.detach())) + + return results + + +def run_jacobian_examples(): + torch.manual_seed(0) + + # Create a small model and load pre-trained parameters. + model_ori = build_model(width=4, linear_size=32) + model_ori.load_state_dict(torch.load('pretrained/cifar_2c2f.pth')) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # Prepare the dataset + test_data = datasets.CIFAR10('./data', train=False, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.4914, 0.4822, 0.4465], std=[0.2009, 0.2009, 0.2009])])) + x0 = test_data[0][0].unsqueeze(0) + + return compute_jacobians(model_ori, x0, device=device) + + +if __name__ == '__main__': + run_jacobian_examples() \ No newline at end of file diff --git a/examples/vision/jacobian_new.py b/examples/vision/jacobian_new.py new file mode 100644 index 0000000..ed9116f --- /dev/null +++ b/examples/vision/jacobian_new.py @@ -0,0 +1,168 @@ +"""Examples of computing Jacobian bounds. + +We show examples of: +- Computing Jacobian bounds +- Computing Linf local Lipschitz constants +- Computing JVP bounds +""" + +import numpy as np +import torch +import torch.nn as nn +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import PerturbationLpNorm +from auto_LiRPA.utils import Flatten +from auto_LiRPA.jacobian import JacobianOP, GradNorm + + +def build_model(in_ch=3, in_dim=32): + model = nn.Sequential( + Flatten(), + nn.Linear(in_ch*in_dim**2, 100), + nn.ReLU(), + nn.Linear(100, 200), + nn.ReLU(), + nn.Linear(200, 10), + ) + return model + + +def example_jacobian(model_ori, x0, bound_opts, device): + """Example: computing Jacobian bounds.""" + + class JacobianWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x): + y = self.model(x) + return JacobianOP.apply(y, x) + + model = BoundedModule(JacobianWrapper(model_ori), x0, bound_opts=bound_opts, device=device) + + def func(x0): + return model_ori(x0.requires_grad_(True)) + ret_ori = torch.autograd.functional.jacobian(func, x0).squeeze(2) + ret_new = model(x0) + assert torch.allclose(ret_ori, ret_new) + + ret = [] + for eps in [0, 1./255, 4./255]: + x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) + lower, upper = model.compute_jacobian_bounds_new(x) + print(f'Gap between upper and lower Jacobian bound for eps={eps:.5f}', + (upper - lower).max()) + if eps == 0: + assert torch.allclose( + ret_new.view(-1), + lower.sum(dim=0, keepdim=True).view(-1)) + assert torch.allclose( + ret_new.view(-1), + upper.sum(dim=0, keepdim=True).view(-1)) + ret.append((lower.detach(), upper.detach())) + + return ret + + +def example_local_lipschitz(model_ori, x0, bound_opts, device): + """Example: computing Linf local Lipschitz constant.""" + + class LocalLipschitzWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.grad_norm = GradNorm(norm=1) + + def forward(self, x, mask): + y = self.model(x) + y_selected = y.matmul(mask) + jacobian = JacobianOP.apply(y_selected, x) + lipschitz = self.grad_norm(jacobian) + return lipschitz + + mask = torch.zeros(10, 1, device=device) + mask[1, 0] = 1 + model = BoundedModule(LocalLipschitzWrapper(model_ori), (x0, mask), + bound_opts=bound_opts, device=device) + + y = model_ori(x0.requires_grad_(True)) + ret_ori = torch.autograd.grad(y[:, 1].sum(), x0)[0].abs().flatten(1).sum(dim=-1).view(-1) + ret_new = model(x0, mask).view(-1) + assert torch.allclose(ret_ori, ret_new) + + ret = [] + for eps in [0, 1./255, 4./255]: + x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) + lip = [] + for i in range(mask.shape[0]): + mask.zero_() + mask[i, 0] = 1 + ub = model.compute_jacobian_bounds_new((x, mask), bound_lower=False)[1] + lip.append(ub) + lip = torch.concat(lip).max() + print(f'Linf local Lipschitz constant for eps={eps:.5f}: {lip.item()}') + ret.append(lip.detach()) + + return ret + + +def example_jvp(model_ori, x0, bound_opts, device): + """Example: computing Jacobian-Vector Product.""" + + class JVPWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + self.grad_norm = GradNorm(norm=1) + + def forward(self, x, v): + y = self.model(x) + jacobian = JacobianOP.apply(y, x).flatten(2) + jvp = (jacobian * v.flatten(1).unsqueeze(1)).sum(dim=-1) + return jvp + + vector = torch.rand_like(x0) + model = BoundedModule(JVPWrapper(model_ori), (x0, vector), + bound_opts=bound_opts, device=device) + + def func(x0): + return model_ori(x0.requires_grad_(True)) + ret_ori = torch.autograd.functional.jvp(func, x0, vector)[-1].view(-1) + ret_new = model(x0, vector) + assert torch.allclose(ret_ori, ret_new) + + ret = [] + for eps in [0, 1./255, 4./255]: + x = BoundedTensor(x0, PerturbationLpNorm(norm=np.inf, eps=eps)) + lb, ub = model.compute_jacobian_bounds_new((x, vector)) + print(f'JVP lower bound for eps={eps:.5f}: {lb}') + print(f'JVP upper bound for eps={eps:.5f}: {ub}') + ret.append((lb, ub)) + + return ret + + +def compute_jacobians_new(model_ori, x0, bound_opts=None, device='cpu'): + results = [[] for _ in range(3)] + + model_ori = model_ori.to(device) + x0 = x0.to(device) + print('Model:', model_ori) + + results[0] = example_jacobian(model_ori, x0, bound_opts, device) + results[1] = example_local_lipschitz(model_ori, x0, bound_opts, device) + results[2] = example_jvp(model_ori, x0, bound_opts, device) + + return results + + +if __name__ == '__main__': + torch.manual_seed(0) + + # Create a small model and load pre-trained parameters. + model_ori = build_model(in_dim=8) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + x0 = torch.randn(1, 3, 8, 8, device=device) + + compute_jacobians_new(model_ori, x0, device=device) diff --git a/examples/vision/models/densenet.py b/examples/vision/models/densenet.py index e36c506..a7d4b28 100644 --- a/examples/vision/models/densenet.py +++ b/examples/vision/models/densenet.py @@ -83,9 +83,8 @@ def forward(self, x): out = self.trans1(self.dense1(out)) out = self.trans2(self.dense2(out)) out = self.dense3(out) - # out = self.dense4(out) out = F.relu(self.bn(out)) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = F.relu(self.linear1(out)) out = self.linear2(out) diff --git a/examples/vision/models/densenet_imagenet.py b/examples/vision/models/densenet_imagenet.py index 6c467c4..e9f5ba2 100644 --- a/examples/vision/models/densenet_imagenet.py +++ b/examples/vision/models/densenet_imagenet.py @@ -83,9 +83,8 @@ def forward(self, x): out = self.trans1(self.dense1(out)) out = self.trans2(self.dense2(out)) out = self.dense3(out) - # out = self.dense4(out) out = F.relu(self.bn(out)) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = F.relu(self.linear1(out)) out = self.linear2(out) diff --git a/examples/vision/models/densenet_no_bn.py b/examples/vision/models/densenet_no_bn.py index c488130..ffea9e6 100644 --- a/examples/vision/models/densenet_no_bn.py +++ b/examples/vision/models/densenet_no_bn.py @@ -84,9 +84,8 @@ def forward(self, x): out = self.trans1(self.dense1(out)) out = self.trans2(self.dense2(out)) out = self.dense3(out) - # out = self.dense4(out) out = F.relu(out) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = F.relu(self.linear1(out)) out = self.linear2(out) diff --git a/examples/vision/models/feedforward.py b/examples/vision/models/feedforward.py index ffd252a..9fd3c98 100644 --- a/examples/vision/models/feedforward.py +++ b/examples/vision/models/feedforward.py @@ -19,7 +19,7 @@ def __init__(self, in_ch, in_dim, width=2, linear_size=256): def forward(self, x): x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) - x = x.view(x.size(0), -1) + x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) @@ -33,7 +33,7 @@ def __init__(self, in_ch, in_dim, width=1): self.fc2 = nn.Linear(256 * width, 10) def forward(self, x): - x = x.view(x.size(0), -1) + x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = self.fc2(x) return x @@ -47,7 +47,7 @@ def __init__(self, in_ch, in_dim, width=1): self.fc3 = nn.Linear(128 * width, 10) def forward(self, x): - x = x.view(x.size(0), -1) + x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) @@ -92,20 +92,15 @@ def __init__(self, in_ch, in_dim, width=1): self.fc5 = nn.Linear(128 * width, 10) def forward(self, x): - x = x.view(x.size(0), -1) + x = torch.flatten(x, 1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = F.relu(self.fc3(x)) - x = F.relu(self.fc4(x)) + x = F.relu(self.fc4(x)) x = self.fc5(x) return x -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) - - # Model can also be defined as a nn.Sequential def cnn_7layer(in_ch=3, in_dim=32, width=64, linear_size=512): model = nn.Sequential( @@ -119,7 +114,7 @@ def cnn_7layer(in_ch=3, in_dim=32, width=64, linear_size=512): nn.ReLU(), nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1), nn.ReLU(), - Flatten(), + nn.Flatten(), nn.Linear((in_dim//2) * (in_dim//2) * 2 * width, linear_size), nn.ReLU(), nn.Linear(linear_size,10) @@ -143,7 +138,7 @@ def cnn_7layer_bn(in_ch=3, in_dim=32, width=64, linear_size=512): nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1), nn.BatchNorm2d(2 * width), nn.ReLU(), - Flatten(), + nn.Flatten(), nn.Linear((in_dim//2) * (in_dim//2) * 2 * width, linear_size), nn.ReLU(), nn.Linear(linear_size,10) @@ -167,13 +162,13 @@ def cnn_7layer_bn_imagenet(in_ch=3, in_dim=32, width=64, linear_size=512): nn.Conv2d(2 * width, 2 * width, 3, stride=2, padding=1), nn.BatchNorm2d(2 * width), nn.ReLU(), - Flatten(), + nn.Flatten(), nn.Linear(25088, linear_size), nn.ReLU(), nn.Linear(linear_size,200) ) return model - + def cnn_6layer(in_ch, in_dim, width=32, linear_size=256): model = nn.Sequential( nn.Conv2d(in_ch, width, 3, stride=1, padding=1), @@ -184,7 +179,7 @@ def cnn_6layer(in_ch, in_dim, width=32, linear_size=256): nn.ReLU(), nn.Conv2d(2 * width, 2 * width, 3, stride=1, padding=1), nn.ReLU(), - Flatten(), + nn.Flatten(), nn.Linear((in_dim//2) * (in_dim//2) * 2 * width, linear_size), nn.ReLU(), nn.Linear(linear_size,10) diff --git a/examples/vision/models/mobilenet.py b/examples/vision/models/mobilenet.py index e054760..12f1970 100644 --- a/examples/vision/models/mobilenet.py +++ b/examples/vision/models/mobilenet.py @@ -72,7 +72,7 @@ def forward(self, x): out = F.relu((self.conv2(out))) # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10 out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = self.linear(out) return out diff --git a/examples/vision/models/resnet.py b/examples/vision/models/resnet.py index d058f3c..72ef803 100644 --- a/examples/vision/models/resnet.py +++ b/examples/vision/models/resnet.py @@ -31,11 +31,6 @@ def forward(self, x): return xs[-1] -class Flatten(nn.Module): - def forward(self, x): - return x.view(x.size(0), -1) - - def model_resnet(in_ch=3, in_dim=32, width=1, mult=16, N=1): def block(in_filters, out_filters, k, downsample): if not downsample: @@ -70,7 +65,7 @@ def block(in_filters, out_filters, k, downsample): conv2 + conv3 + conv4 + - [Flatten(), + [nn.Flatten(), nn.Linear(mult * 4 * width * 8 * 8, 1000), nn.ReLU(), nn.Linear(1000, 10)] diff --git a/examples/vision/models/resnet18.py b/examples/vision/models/resnet18.py index 05897cd..a8c6ba1 100644 --- a/examples/vision/models/resnet18.py +++ b/examples/vision/models/resnet18.py @@ -99,13 +99,13 @@ def forward(self, x): out = self.layer3(out) out = self.layer4(out) out = F.avg_pool2d(out, 4) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = self.linear(out) return out def ResNet18(in_planes=64): return ResNet(BasicBlock, [2, 2, 2, 2], in_planes=in_planes) - + if __name__ == "__main__": from thop import profile net = ResNet18(in_planes=64) diff --git a/examples/vision/models/resnext.py b/examples/vision/models/resnext.py index 681710a..78424e6 100644 --- a/examples/vision/models/resnext.py +++ b/examples/vision/models/resnext.py @@ -68,15 +68,11 @@ def _make_layer(self, num_blocks, stride): return nn.Sequential(*layers) def forward(self, x): - # out = F.relu(self.bn1(self.conv1(x))) out = F.relu(self.conv1(x)) - out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) - # out = self.layer4(out) - # out = F.avg_pool2d(out, 8) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = F.relu(self.linear1(out)) out = self.linear2(out) return out diff --git a/examples/vision/models/resnext_imagenet64.py b/examples/vision/models/resnext_imagenet64.py index 9ef4181..d056255 100644 --- a/examples/vision/models/resnext_imagenet64.py +++ b/examples/vision/models/resnext_imagenet64.py @@ -68,21 +68,15 @@ def _make_layer(self, num_blocks, stride): return nn.Sequential(*layers) def forward(self, x): - # out = F.relu(self.bn1(self.conv1(x))) out = F.relu(self.conv1(x)) - out = self.layer1(out) out = self.layer2(out) out = self.layer3(out) - # out = self.layer4(out) - # out = F.avg_pool2d(out, 8) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = F.relu(self.linear1(out)) out = self.linear2(out) return out - - def ResNeXt_imagenet64(): return ResNeXt(num_blocks=[2,2,2], cardinality=2, bottleneck_width=8) diff --git a/examples/vision/models/vnncomp_resnet.py b/examples/vision/models/vnncomp_resnet.py index 05c5fb0..04ad729 100644 --- a/examples/vision/models/vnncomp_resnet.py +++ b/examples/vision/models/vnncomp_resnet.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.autograd import Variable + class BasicBlock(nn.Module): expansion = 1 @@ -98,7 +98,7 @@ def forward(self, x): out = self.layer1(out) if self.last_layer == "avg": out = self.avg2d(out) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = self.linear(out) elif self.last_layer == "dense": out = torch.flatten(out, 1) @@ -144,7 +144,7 @@ def forward(self, x): out = self.layer2(out) if self.last_layer == "avg": out = self.avg2d(out) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = self.linear(out) elif self.last_layer == "dense": out = torch.flatten(out, 1) @@ -159,6 +159,7 @@ def resnet2b(): def resnet4b(): return ResNet9(BasicBlock, num_blocks=2, in_planes=16, bn=False, last_layer="dense") + if __name__ == '__main__': print('ResNet-2B:\n', resnet2b()) print('ResNet-4B:\n', resnet4b()) diff --git a/examples/vision/models/wide_resnet_cifar.py b/examples/vision/models/wide_resnet_cifar.py index e996ac3..70d2574 100644 --- a/examples/vision/models/wide_resnet_cifar.py +++ b/examples/vision/models/wide_resnet_cifar.py @@ -97,7 +97,7 @@ def forward(self, x): out = F.relu(out) if self.use_pooling: out = F.avg_pool2d(out, 8) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = F.relu(self.linear1(out)) out = self.linear2(out) diff --git a/examples/vision/models/wide_resnet_imagenet64.py b/examples/vision/models/wide_resnet_imagenet64.py index 27c4d16..7a5418b 100644 --- a/examples/vision/models/wide_resnet_imagenet64.py +++ b/examples/vision/models/wide_resnet_imagenet64.py @@ -42,7 +42,7 @@ def forward(self, x): return out class Wide_ResNet(nn.Module): - def __init__(self, depth, widen_factor, dropout_rate, num_classes, + def __init__(self, depth, widen_factor, dropout_rate, num_classes, in_planes=16, in_dim=56): super(Wide_ResNet, self).__init__() self.in_planes = in_planes @@ -78,7 +78,7 @@ def forward(self, x): out = self.layer3(out) out = F.relu(self.bn1(out)) out = F.avg_pool2d(out, 7) - out = out.view(out.size(0), -1) + out = torch.flatten(out, 1) out = self.linear(out) return out diff --git a/examples/vision/pretrained/test_min_max.pth b/examples/vision/pretrained/test_min_max.pth new file mode 100644 index 0000000..22e0e98 Binary files /dev/null and b/examples/vision/pretrained/test_min_max.pth differ diff --git a/examples/vision/save_intermediate_bound.py b/examples/vision/save_intermediate_bound.py new file mode 100644 index 0000000..c5aefd8 --- /dev/null +++ b/examples/vision/save_intermediate_bound.py @@ -0,0 +1,56 @@ +""" +A simple example for saving intermediate bounds. +""" +import os +import torch +import torch.nn as nn +import torchvision +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import PerturbationLpNorm +from auto_LiRPA.utils import Flatten + +def mnist_model(): + model = nn.Sequential( + nn.Conv2d(1, 16, 4, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(16, 32, 4, stride=2, padding=1), + nn.ReLU(), + Flatten(), + nn.Linear(32*7*7,100), + nn.ReLU(), + nn.Linear(100, 10) + ) + return model + +model = mnist_model() +# Optionally, load the pretrained weights. +checkpoint = torch.load( + os.path.join(os.path.dirname(__file__), 'pretrained/mnist_a_adv.pth'), + map_location=torch.device('cpu')) +model.load_state_dict(checkpoint) + +test_data = torchvision.datasets.MNIST( + './data', train=False, download=True, + transform=torchvision.transforms.ToTensor()) +# For illustration we only use 2 image from dataset +N = 2 +n_classes = 10 +image = test_data.data[:N].view(N,1,28,28) +true_label = test_data.targets[:N] +# Convert to float +image = image.to(torch.float32) / 255.0 +if torch.cuda.is_available(): + image = image.cuda() + model = model.cuda() + +lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device) +print('Running on', image.device) + +eps = 0.3 +norm = float("inf") +ptb = PerturbationLpNorm(norm = norm, eps = eps) +image = BoundedTensor(image, ptb) + +lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1, }}) +lb, ub = lirpa_model.compute_bounds(x=(image,), method='CROWN-Optimized') +save_dict = lirpa_model.save_intermediate('./mnist_a_adv_bounds.npy') \ No newline at end of file diff --git a/examples/vision/simple_verification.py b/examples/vision/simple_verification.py index 7085f01..103fbb2 100644 --- a/examples/vision/simple_verification.py +++ b/examples/vision/simple_verification.py @@ -134,4 +134,3 @@ def mnist_model(): print('margin bounds: {l:8.3f} <= f_{j}(x_0+delta) - f_{target}(x_0+delta) <= {u:8.3f}'.format( j=true_label[i], target=(true_label[i] + 1) % n_classes, l=lb[i][0].item(), u=ub[i][0].item())) print() - diff --git a/setup.py b/setup.py index 590b756..14ba8b5 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,9 @@ # Check PyTorch version pytorch_version_l = '1.11.0' -pytorch_version_u = '1.13.0' # excluded +pytorch_version_u = '2.1.0' # excluded msg_install_pytorch = (f'It is recommended to manually install PyTorch ' - f'(>={pytorch_version_u},<{pytorch_version_u}) suitable ' + f'(>={pytorch_version_l},<{pytorch_version_u}) suitable ' 'for your system ahead: https://pytorch.org/get-started.\n') try: import torch @@ -33,13 +33,13 @@ description='A library for Automatic Linear Relaxation based Perturbation Analysis (LiRPA) on general computational graphs, with a focus on adversarial robustness verification and certification of deep neural networks.', long_description=long_description, long_description_content_type='text/markdown', - url='https://github.com/KaidiXu/auto_LiRPA', - author='Kaidi Xu, Zhouxing Shi, Huan Zhang, Yihan Wang, Shiqi Wang, Linyi Li, Jinqi (Kathryn) Chen, Zhuolin Yang', - author_email='xu.kaid@husky.neu.edu, zhouxingshichn@gmail.com, huan@huan-zhang.com, wangyihan617@gmail.com, sw3215@columbia.edu,linyi2@illinois.edu,jinqic@cs.cmu.edu,zhuolin5@illinois.edu', + url='https://github.com/Verified-Intelligence/auto_LiRPA', + author='Kaidi Xu, Zhouxing Shi, Huan Zhang, Yihan Wang, Shiqi Wang, Linyi Li, Jinqi (Kathryn) Chen, Zhuolin Yang, Christopher Brix, Xiangru Zhong, Qirui Jin, Zhuowen Yuan', + author_email='xu.kaid@husky.neu.edu, zhouxingshichn@gmail.com, huan@huan-zhang.com, wangyihan617@gmail.com, sw3215@columbia.edu, linyi2@illinois.edu, jinqic@cs.cmu.edu, zhuolin5@illinois.edu, brix@cs.rwth-aachen.de, xiangruzh0915@gmail.com, qiruijin@umich.edu, realzhuowen@gmail.com', packages=find_packages(), install_requires=[ f'torch>={pytorch_version_l},<{pytorch_version_u}', - 'torchvision>=0.9,<0.14', + 'torchvision>=0.9', 'numpy>=1.20', 'packaging>=20.0', 'pytest>=5.0', diff --git a/tests/data/conv1d_test_data_3-0-2 b/tests/data/conv1d_test_data_3-0-2 new file mode 100644 index 0000000..e3c0ac9 Binary files /dev/null and b/tests/data/conv1d_test_data_3-0-2 differ diff --git a/tests/data/conv1d_test_data_3-0-3 b/tests/data/conv1d_test_data_3-0-3 new file mode 100644 index 0000000..3885ad2 Binary files /dev/null and b/tests/data/conv1d_test_data_3-0-3 differ diff --git a/tests/data/conv1d_test_data_3-1-2 b/tests/data/conv1d_test_data_3-1-2 new file mode 100644 index 0000000..1641b50 Binary files /dev/null and b/tests/data/conv1d_test_data_3-1-2 differ diff --git a/tests/data/conv1d_test_data_3-1-3 b/tests/data/conv1d_test_data_3-1-3 new file mode 100644 index 0000000..5c184f3 Binary files /dev/null and b/tests/data/conv1d_test_data_3-1-3 differ diff --git a/tests/data/conv1d_test_data_4-0-2 b/tests/data/conv1d_test_data_4-0-2 new file mode 100644 index 0000000..ef6eabe Binary files /dev/null and b/tests/data/conv1d_test_data_4-0-2 differ diff --git a/tests/data/conv1d_test_data_4-0-3 b/tests/data/conv1d_test_data_4-0-3 new file mode 100644 index 0000000..e5c647c Binary files /dev/null and b/tests/data/conv1d_test_data_4-0-3 differ diff --git a/tests/data/conv1d_test_data_4-1-2 b/tests/data/conv1d_test_data_4-1-2 new file mode 100644 index 0000000..91de6dc Binary files /dev/null and b/tests/data/conv1d_test_data_4-1-2 differ diff --git a/tests/data/conv1d_test_data_4-1-3 b/tests/data/conv1d_test_data_4-1-3 new file mode 100644 index 0000000..785b1dd Binary files /dev/null and b/tests/data/conv1d_test_data_4-1-3 differ diff --git a/tests/data/jacobian_test_data b/tests/data/jacobian_test_data index 51d32b2..f366216 100644 Binary files a/tests/data/jacobian_test_data and b/tests/data/jacobian_test_data differ diff --git a/tests/data/language_test_data b/tests/data/language_test_data index c00ac4b..11d436a 100644 Binary files a/tests/data/language_test_data and b/tests/data/language_test_data differ diff --git a/tests/data/min_max_test_data b/tests/data/min_max_test_data new file mode 100644 index 0000000..74fda03 Binary files /dev/null and b/tests/data/min_max_test_data differ diff --git a/tests/data/test_save_data b/tests/data/test_save_data new file mode 100644 index 0000000..e0e3eb6 Binary files /dev/null and b/tests/data/test_save_data differ diff --git a/tests/requirements.txt b/tests/requirements.txt deleted file mode 100644 index 9da241c..0000000 --- a/tests/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -onnxruntime>=1.12 -git+https://github.com/KaidiXu/onnx2pytorch.git \ No newline at end of file diff --git a/tests/test_1d_activation.py b/tests/test_1d_activation.py index 2a0a552..934ccc5 100644 --- a/tests/test_1d_activation.py +++ b/tests/test_1d_activation.py @@ -1,4 +1,5 @@ """Test one dimensional activation functions (e.g., ReLU, tanh, exp, sin, etc)""" +import pytest import torch import torch.nn as nn from testcase import TestCase @@ -6,6 +7,7 @@ from auto_LiRPA.perturbations import * from auto_LiRPA.utils import logger + # Wrap the computation with a nn.Module class test_model(nn.Module): def __init__(self, act_func): @@ -15,17 +17,39 @@ def __init__(self, act_func): def forward(self, x): return self.act_func(x) +def pow_2(x): + return torch.pow(x, 2) + +def pow_3(x): + return torch.pow(x, 3) + +class GELUOp(torch.autograd.Function): + @staticmethod + def symbolic(g, x): + return g.op('custom::Gelu', x) + + @staticmethod + def forward(ctx, x): + return torch.nn.functional.gelu(x) + +def GELU(x): + return GELUOp.apply(x) + class Test1DActivation(TestCase): def __init__(self, methodName='runTest'): super().__init__(methodName) - def create_test(self, act_func, low, high, ntests=10000, nsamples=1000, method='IBP'): - print(f'Testing activation {act_func}') + def create_test(self, act_func, low, high, ntests=1000, nsamples=1000, + method='IBP'): + print(f'Testing activation {act_func} (method {method})') model = test_model(act_func) image = torch.zeros(1, ntests) - bounded_model = BoundedModule(model, image) + bounded_model = BoundedModule( + model, image, bound_opts={ + 'optimize_bound_args': {'iteration': 2}, + }) # Generate randomly bounded inputs. p = torch.rand(1, ntests) * (high - low ) + low @@ -61,93 +85,82 @@ def lookup(l, u): real_lb = real_lb.view(*shape) real_ub = real_ub.view(*shape) return real_lb, real_ub - # These are reference results. IBP results should be very close to these. Linear bound results can be looser than these. + + # These are reference results. IBP results should be very close to these. + # Linear bound results can be looser than these. ref_forward = model(input_center) ref_output_lb, ref_output_ub = lookup(input_lb, input_ub) # Get bounding results. forward = bounded_model(ptb_data) - output_lb, output_ub = bounded_model.compute_bounds(x=(ptb_data,), method = method) + if act_func in [torch.sin, torch.cos]: + bounded_model.set_bound_opts({ + 'optimize_bound_args': {'iteration': 2, 'init_alpha': False}, + }) + bounded_model.init_alpha(x=(ptb_data,), skip_bound_compute=True) + node = bounded_model.optimizable_activations[0] + shape = node.alpha['/1'].data[0:2].shape + node.alpha['/1'].data[8:10, :] = (node.alpha['/1'][8:10, :] + - node.tp_right_lower_init['/1']) * torch.rand(*shape) + node.tp_right_lower_init['/1'] + node.alpha['/1'].data[10:12, :] = (node.alpha['/1'][10:12, :] + - node.tp_right_upper_init['/1']) * torch.rand(*shape) + node.tp_right_upper_init['/1'] + + output_lb, output_ub = bounded_model.compute_bounds( + x=(ptb_data,), method=method) + bounded_model.set_bound_opts({ + 'optimize_bound_args': {'iteration': 2, 'init_alpha': True}, + }) # Compare. assert torch.allclose(forward, ref_forward) for i in range(ntests): show = False if output_ub[0,i] < ref_output_ub[0,i] - 1e-5: - logger.warn(f'upper bound is wrong {ref_output_ub[0,i] - output_ub[0,i]}') + logger.warning(f'upper bound is wrong {ref_output_ub[0,i] - output_ub[0,i]}') show = True if output_lb[0,i] > ref_output_lb[0,i] + 1e-5: - logger.warn(f'lower bound is wrong {output_lb[0,i] - ref_output_lb[0,i]}') + logger.warning(f'lower bound is wrong {output_lb[0,i] - ref_output_lb[0,i]}') show = True if show: - logger.warn(f'input_lb={input_lb[0,i]:8.3f}, input_ub={input_ub[0,i]:8.3f}, lb={output_lb[0,i]:8.3f}, ref_lb={ref_output_lb[0,i]:8.3f}, ub={output_ub[0,i]:8.3f}, ref_ub={ref_output_ub[0,i]:8.3f}') + logger.warning(f'input_lb={input_lb[0,i]:8.3f}, input_ub={input_ub[0,i]:8.3f}, lb={output_lb[0,i]:8.3f}, ref_lb={ref_output_lb[0,i]:8.3f}, ub={output_ub[0,i]:8.3f}, ref_ub={ref_output_ub[0,i]:8.3f}') assert torch.all(output_ub + 1e-5 >= ref_output_ub) assert torch.all(output_lb - 1e-5 <= ref_output_lb) - - def _single(self): - model = test_model(torch.sin) - image = torch.zeros(1, 1) - bounded_model = BoundedModule(model, image) - - input_lb = torch.tensor([2.817]) - input_ub = torch.tensor([5.196]) - input_center = (input_lb + input_ub) / 2.0 - ptb = PerturbationLpNorm(norm=float("inf"), eps=None, x_L=input_lb, x_U=input_ub) - ptb_data = BoundedTensor(input_center, ptb) - - # Get bounding results. - forward = bounded_model(ptb_data) - output_lb, output_ub = bounded_model.compute_bounds(x=(ptb_data,), method = 'CROWN') - print(output_lb, output_ub) - - def test_relu(self): - self.create_test(act_func=torch.nn.functional.relu, low=-10, high=10, method='IBP') - self.create_test(act_func=torch.nn.functional.relu, low=-10, high=10, method='CROWN') - - - def test_exp(self): - self.create_test(act_func=torch.exp, low=-3, high=3, method='IBP') - self.create_test(act_func=torch.exp, low=-3, high=3, method='CROWN') - - - def test_reciprocal(self): - # So far only positive values are supported. - self.create_test(act_func=torch.reciprocal, low=0.01, high=10, method='IBP') - self.create_test(act_func=torch.reciprocal, low=0.01, high=10, method='CROWN') - - - def test_tanh(self): - self.create_test(act_func=torch.tanh, low=-5, high=5, method='IBP') - self.create_test(act_func=torch.tanh, low=-5, high=5, method='CROWN') - - - def test_sin(self): - self.create_test(act_func=torch.sin, low=-10, high=10, method='IBP') - self.create_test(act_func=torch.sin, low=-10, high=10, method='CROWN') - - - def test_cos(self): - self.create_test(act_func=torch.cos, low=-10, high=10, method='IBP') - self.create_test(act_func=torch.cos, low=-10, high=10, method='CROWN') - - def test_arctan(self): - self.create_test(act_func=torch.arctan, low=-10, high=10, method='IBP') - self.create_test(act_func=torch.arctan, low=-10, high=10, method='CROWN') - + @pytest.mark.skip(reason="Known issue: https://github.com/Verified-Intelligence/Verifier_Development/issues/164") def test_tan(self): # Test tan(x) in different periods. for i in range(-5, 5): - self.create_test(act_func=torch.arctan, low=-0.5*torch.pi + i*torch.pi + 1e-20, high=0.5*torch.pi + i*torch.pi - 1e-20, method='IBP') - self.create_test(act_func=torch.arctan, low=-0.5*torch.pi + i*torch.pi + 1e-20, high=0.5*torch.pi + i*torch.pi - 1e-20, method='CROWN') + self.create_test( + act_func=torch.tan, + low=-0.5*torch.pi + i*torch.pi + 1e-20, + high=0.5*torch.pi + i*torch.pi - 1e-20, method='IBP') + self.create_test( + act_func=torch.tan, + low=-0.5*torch.pi + i*torch.pi + 1e-20, + high=0.5*torch.pi + i*torch.pi - 1e-20, method='CROWN') + + def test_acts(self): + for act_func in [torch.nn.functional.relu, + torch.sin, torch.cos, + torch.tanh, torch.arctan, + torch.exp, pow_2, pow_3, + torch.sign, GELU]: + low, high = -10, 10 + if act_func == torch.reciprocal: + # So far only positive values are supported. + low = 0.01 + self.create_test(act_func=act_func, low=low, high=high, method='IBP') + self.create_test(act_func=act_func, low=low, high=high, method='CROWN') + if act_func not in [torch.exp, torch.sign, torch.sin, torch.cos]: + # Use optimized bounds + self.create_test(act_func=act_func, low=low, high=high, + method='CROWN-Optimized') + if act_func in [torch.sin, torch.cos]: + test_samples = 10 + for _ in range(test_samples): + self.create_test(act_func=act_func, low=low, high=high, method='CROWN-Optimized') + if __name__ == '__main__': testcase = Test1DActivation() - testcase.test_relu() - testcase.test_reciprocal() - testcase.test_exp() - testcase.test_tanh() - testcase.test_sin() - testcase.test_cos() - testcase.test_arctan() - testcase.test_tan() + testcase.test_acts() diff --git a/tests/test_2d_activation.py b/tests/test_2d_activation.py new file mode 100644 index 0000000..9661f1b --- /dev/null +++ b/tests/test_2d_activation.py @@ -0,0 +1,131 @@ +"""Test two dimensional activation functions (e.g., min, max, etc)""" +import tqdm +import torch +import torch.nn as nn +from testcase import TestCase +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import * +from auto_LiRPA.utils import logger + + +# Wrap the computation with a nn.Module +class test_model(nn.Module): + def __init__(self, act_func): + super().__init__() + self.act_func = act_func + + def forward(self, x, y): + return self.act_func(x, y) + + +def mul(x, y): + return x * y + + +class Test2DActivation(TestCase): + def __init__(self, methodName='runTest'): + super().__init__(methodName) + + def create_test(self, act_func, low_x, high_x, low_y, high_y, + ntests=10000, nsamples=1000, method='IBP'): + print(f'Testing activation {act_func}') + + model = test_model(act_func) + image = torch.zeros(2, ntests) + bounded_model = BoundedModule(model, (image[0], image[1]), device=torch.device('cpu')) + + # Generate randomly bounded inputs. + p_x = torch.rand(1, ntests) * (high_x - low_x) + low_x + q_x = torch.rand(1, ntests) * (high_x - low_x) + low_x + input_lb_x = torch.min(p_x, q_x) + input_ub_x = torch.max(p_x, q_x) + input_center_x = (input_lb_x + input_ub_x) / 2.0 + ptb_x = PerturbationLpNorm(x_L=input_lb_x, x_U=input_ub_x) + ptb_data_x = BoundedTensor(input_center_x, ptb_x) + + p_y = torch.rand(1, ntests) * (high_y - low_y) + low_y + q_y = torch.rand(1, ntests) * (high_y - low_y) + low_y + input_lb_y = torch.min(p_y, q_y) + input_ub_y = torch.max(p_y, q_y) + input_center_y = (input_lb_y + input_ub_y) / 2.0 + ptb_y = PerturbationLpNorm(x_L=input_lb_y, x_U=input_ub_y) + ptb_data_y = BoundedTensor(input_center_y, ptb_y) + + # Generate reference results. + range_xy = torch.linspace(start=low_x, end=high_x, steps=nsamples+1) + table = torch.empty([range_xy.shape[0], range_xy.shape[0]]) + for i in range(range_xy.shape[0]): + x = range_xy[i] + table_y = act_func(x, torch.linspace(start=low_y, end=high_y, steps=nsamples+1)) + table[i] = table_y + def lookup(l_x, u_x, l_y, u_y): + assert torch.all(u_x <= high_x) + assert torch.all(l_x >= low_x) + assert torch.all(u_y <= high_y) + assert torch.all(l_y >= low_y) + shape = l_x.size() + l_x = l_x.squeeze() + u_x = u_x.squeeze() + l_y = l_y.squeeze() + u_y = u_y.squeeze() + # select all sample points between l and u. + low_index_x = torch.ceil((l_x - low_x) / (high_x - low_x) * nsamples).int() # Make sure we do not have index 0. + high_index_x = torch.floor((u_x - low_x) / (high_x - low_x) * nsamples).int() + low_index_y = torch.ceil((l_y - low_y) / (high_y - low_y) * nsamples).int() # Make sure we do not have index 0. + high_index_y = torch.floor((u_y - low_y) / (high_y - low_y) * nsamples).int() + real_lb = torch.empty_like(l_x) + real_ub = torch.empty_like(u_x) + for i, (li_x, hi_x) in enumerate(zip(low_index_x, high_index_x)): + li_y = low_index_y[i] + hi_y = high_index_y[i] + if li_x == hi_x + 1 or li_y == hi_y + 1: + # Not enough precision. l and u are too close so we cannot tell. + real_lb[i] = float("inf") + real_ub[i] = float("-inf") + else: + selected = table[li_x : hi_x+1, li_y : hi_y+1].reshape(-1) + real_lb[i] = torch.min(selected) + real_ub[i] = torch.max(selected) + real_lb = real_lb.view(*shape) + real_ub = real_ub.view(*shape) + return real_lb, real_ub + # These are reference results. IBP results should be very close to these. Linear bound results can be looser than these. + ref_forward = model(input_center_x, input_center_y) + ref_output_lb, ref_output_ub = lookup(input_lb_x, input_ub_x, input_lb_y, input_ub_y) + + # Get bounding results. + forward = bounded_model(ptb_data_x, ptb_data_y) + output_lb, output_ub = bounded_model.compute_bounds(x=(ptb_data_x, ptb_data_y), method = method) + + # Compare. + assert torch.allclose(forward, ref_forward) + for i in tqdm.tqdm(range(ntests)): + show = False + if output_ub[0,i] < ref_output_ub[0,i] - 1e-5: + logger.warning(f'upper bound is wrong {ref_output_ub[0,i] - output_ub[0,i]}') + show = True + if output_lb[0,i] > ref_output_lb[0,i] + 1e-5: + logger.warning(f'lower bound is wrong {output_lb[0,i] - ref_output_lb[0,i]}') + show = True + if show: + logger.warning(f'input_lb_x={input_lb_x[0,i]:8.3f}, input_ub_x={input_ub_x[0,i]:8.3f},input_lb_y={input_lb_y[0,i]:8.3f}, input_ub_y={input_ub_y[0,i]:8.3f}, lb={output_lb[0,i]:8.3f}, ref_lb={ref_output_lb[0,i]:8.3f}, ub={output_ub[0,i]:8.3f}, ref_ub={ref_output_ub[0,i]:8.3f}') + assert torch.all(output_ub + 1e-5 >= ref_output_ub) + assert torch.all(output_lb - 1e-5 <= ref_output_lb) + + def test_max(self): + self.create_test(act_func=torch.max, low_x=-10, high_x=5, low_y=-1, high_y=10, method='IBP') + self.create_test(act_func=torch.max, low_x=-10, high_x=5, low_y=-1, high_y=10, method='CROWN') + + def test_min(self): + self.create_test(act_func=torch.min, low_x=-10, high_x=5, low_y=-1, high_y=10, method='IBP') + self.create_test(act_func=torch.min, low_x=-10, high_x=5, low_y=-1, high_y=10, method='CROWN') + + def test_mul(self): + self.create_test(act_func=mul, low_x=-10, high_x=5, low_y=-1, high_y=10, method='IBP') + self.create_test(act_func=mul, low_x=-10, high_x=5, low_y=-1, high_y=10, method='CROWN') + +if __name__ == '__main__': + testcase = Test2DActivation() + testcase.test_max() + testcase.test_min() + testcase.test_mul() diff --git a/tests/test_bound_ops.py b/tests/test_bound_ops.py index a49adc4..59e8564 100644 --- a/tests/test_bound_ops.py +++ b/tests/test_bound_ops.py @@ -5,17 +5,18 @@ from testcase import TestCase -"""Dummy node for testing""" class Dummy: + """Dummy node for testing""" def __init__(self, lower, upper=None, perturbed=False): self.lower = lower self.upper = upper if upper is not None else lower self.perturbed = perturbed self.output_shape = lower.shape -class TestBoundOp(TestCase): + +class TestBoundOp(TestCase): def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName, + super().__init__(methodName, seed=1, ref_path='data/bound_ops_data', generate=generate) @@ -25,20 +26,23 @@ def test(self): dim_final = 7 dim_output = 11 dim_input = 11 - + # multiplication of [batch_size, dim_input] and [dim_output, dim_input]^T weight = torch.randn(dim_output, dim_input, device=device) bias = torch.randn(dim_output, device=device) data_in = torch.randn(batch_size, dim_input, device=device) data_in_delta = torch.randn(batch_size, dim_input, device=device) - dummy_in = Dummy(data_in - torch.abs(data_in_delta), data_in + torch.abs(data_in_delta), True) + dummy_in = Dummy( + data_in - torch.abs(data_in_delta), + data_in + torch.abs(data_in_delta), True) dummy_weight = Dummy(weight) dummy_bias = Dummy(bias) op = BoundLinear( - attr={}, + attr={}, inputs=[dummy_in, dummy_weight, dummy_bias], output_index=0, options={}) + op.batch_dim = 0 # test `forward` data_out = op(data_in, weight, bias) @@ -67,26 +71,32 @@ def test(self): bound_weight = LinearBound(None, None, None, None, dummy_weight.lower, dummy_weight.upper) bound_bias = LinearBound(None, None, None, None, dummy_bias.lower, dummy_bias.upper) bound_out = op.bound_forward(dim_final, bound_in, bound_weight, bound_bias) - self.assertEqual(bound_out.lw, - bound_in.lw.matmul(weight.t().clamp(min=0)) + bound_in.uw.matmul(weight.t().clamp(max=0))) - self.assertEqual(bound_out.uw, - bound_in.uw.matmul(weight.t().clamp(min=0)) + bound_in.lw.matmul(weight.t().clamp(max=0))) - self.assertEqual(bound_out.lb, - bound_in.lb.matmul(weight.t().clamp(min=0)) + bound_in.ub.matmul(weight.t().clamp(max=0)) + bias) - self.assertEqual(bound_out.ub, - bound_in.ub.matmul(weight.t().clamp(min=0)) + bound_in.lb.matmul(weight.t().clamp(max=0)) + bias) + self.assertEqual( + bound_out.lw, bound_in.lw.matmul(weight.t().clamp(min=0)) + + bound_in.uw.matmul(weight.t().clamp(max=0))) + self.assertEqual( + bound_out.uw, bound_in.uw.matmul(weight.t().clamp(min=0)) + + bound_in.lw.matmul(weight.t().clamp(max=0))) + self.assertEqual( + bound_out.lb, bound_in.lb.matmul(weight.t().clamp(min=0)) + + bound_in.ub.matmul(weight.t().clamp(max=0)) + bias) + self.assertEqual( + bound_out.ub, bound_in.ub.matmul(weight.t().clamp(min=0)) + + bound_in.lb.matmul(weight.t().clamp(max=0)) + bias) # test `interval_propagate` bound_in = ( - torch.randn(*data_in.shape, device=device), + torch.randn(*data_in.shape, device=device), torch.randn(*data_in.shape, device=device)) bound_weight = (bound_weight.lower, bound_weight.upper) bound_bias = (bound_bias.lower, bound_bias.upper) bound_out = op.interval_propagate(bound_in, bound_weight, bound_bias) - self.assertEqual(bound_out[0], - bound_in[0].matmul(weight.t().clamp(min=0)) + bound_in[1].matmul(weight.t().clamp(max=0)) + bias) - self.assertEqual(bound_out[1], - bound_in[1].matmul(weight.t().clamp(min=0)) + bound_in[0].matmul(weight.t().clamp(max=0)) + bias) + self.assertEqual(bound_out[0], + bound_in[0].matmul(weight.t().clamp(min=0)) + + bound_in[1].matmul(weight.t().clamp(max=0)) + bias) + self.assertEqual(bound_out[1], + bound_in[1].matmul(weight.t().clamp(min=0)) + + bound_in[0].matmul(weight.t().clamp(max=0)) + bias) # test weight perturbation # `bound_backward` @@ -112,13 +122,15 @@ def test(self): # legacy reference if ref.shape[0] == batch_size: ref = ref.transpose(0, 1) - self.assertEqual(A[i][j], ref) + self.assertEqual(A[i][j], ref) lbias, ubias = lbias.transpose(0, 1), ubias.transpose(0, 1) self.assertEqual(lbias, lbias_ref) self.assertEqual(ubias, ubias_ref) - self.assertEqual(bound_out[0], bound_out_ref[0]) and equal(bound_out[1], bound_out_ref[1]) + self.assertEqual(bound_out[0], bound_out_ref[0]) + self.assertEqual(bound_out[1], bound_out_ref[1]) + if __name__ == '__main__': # Change to generate=True when genearting reference results testcase = TestBoundOp(generate=False) - testcase.test() \ No newline at end of file + testcase.test() diff --git a/tests/test_branching_heuristics.py b/tests/test_branching_heuristics.py new file mode 100644 index 0000000..29eb5df --- /dev/null +++ b/tests/test_branching_heuristics.py @@ -0,0 +1,83 @@ +import sys +import torch +from types import SimpleNamespace + +sys.path.insert(0, '../complete_verifier') + +from heuristics.base import RandomNeuronBranching + + +def test_branching_heuristics(): + import random + import numpy as np + seed = 123 + torch.manual_seed(seed) + random.seed(seed) + np.random.seed(seed) + + net = SimpleNamespace() + branching_heuristic = RandomNeuronBranching(net) + + for _ in range(10000): + batch_size = random.randint(1, 5) + # Number of layers, and we will split the total_layers into this + # many of layers. + n_layers = random.randint(1, 5) + total_len = random.randint(n_layers, 100) + net.split_nodes = [] + net.split_activations = {} + for i in range(n_layers): + layer = SimpleNamespace() + layer.name = i + activation = SimpleNamespace() + activation.name = f'{i}_activation' + net.split_nodes.append(layer) + net.split_activations[layer.name] = [(activation, 0)] + # Total number of neurons in all layers. + topk = random.randint(1, total_len) + # Generate random and unique scores. + # scores = torch.argsort(torch.rand(batch_size, total_len)) + 1 + scores = torch.rand(batch_size, total_len) + 1e-8 + # Generate random mask. Mask = 1 means this neuron can be split. + masks = (torch.rand(batch_size, total_len) > 0.75).float() + # Generate random split locations. + split_position = torch.randint( + low=0, high=total_len, size=(n_layers - 1,)).sort().values + print(f'testing batch={batch_size}, n_layers={n_layers}, ' + f'total_len={total_len}, topk={topk}, split={split_position}') + segment_lengths = (torch.cat( + [split_position, torch.full(size=(1,), + fill_value=total_len, + device=split_position.device)]) + - torch.cat([torch.zeros((1,), device=split_position.device), + split_position])) + segment_lengths = segment_lengths.int().tolist() + # Cap to the minimum number of valid neurons in each batch. + min_k = int(masks.sum(dim=1).min().item()) + # Find the topk scores and indices across all layers. + topk_scores, topk_indices = (scores * masks).topk(k=min(min_k, topk)) + # Map the indices to groundtruth layer number. + topk_layers = torch.searchsorted( + split_position, topk_indices, right=True) + # Map the indices to groundtruth neuron number. + topk_neurons = topk_indices - torch.cat( + [torch.zeros(1, device=split_position.device, dtype=torch.int64), + split_position] + ).view(1, -1).repeat(batch_size, 1).gather( + dim=1, index=topk_layers) + # Split into a list of scores for testing. + all_layer_scores = scores.split(segment_lengths, dim=1) + all_layer_masks = masks.split(segment_lengths, dim=1) + all_layer_scores = {i: item for i, item in enumerate(all_layer_scores)} + all_layer_masks = {i: item for i, item in enumerate(all_layer_masks)} + branching_heuristic.update_batch_size_and_device(all_layer_scores) + (calculated_layers, calculated_neurons, + calculated_scores) = branching_heuristic.find_topk_scores( + all_layer_scores, all_layer_masks, k=topk, return_scores=True) + torch.testing.assert_close(calculated_layers, topk_layers) + torch.testing.assert_close(calculated_neurons, topk_neurons) + torch.testing.assert_close(calculated_scores, topk_scores) + + +if __name__ == "__main__": + test_branching_heuristics() diff --git a/tests/test_conv.py b/tests/test_conv.py index b9e571f..6349e6d 100644 --- a/tests/test_conv.py +++ b/tests/test_conv.py @@ -1,18 +1,9 @@ import torch -import os import torch.nn as nn -import torch.nn.functional as F -import torchvision from auto_LiRPA import BoundedModule, BoundedTensor -from auto_LiRPA.perturbations import * +from auto_LiRPA.perturbations import * from testcase import TestCase -class Flatten(nn.Module): - def __init__(self): - super(Flatten, self).__init__() - - def forward(self, x): - return x.view((x.shape[0], -1)) class cnn_model(nn.Module): def __init__(self, layers, padding, stride, linear=True): @@ -26,7 +17,7 @@ def __init__(self, layers, padding, stride, linear=True): length = (length + 2 * padding - 4)//stride + 1 assert length > 0 self.module_list.append(nn.ReLU()) - self.module_list.append(Flatten()) + self.module_list.append(nn.Flatten()) if linear: self.module_list.append(nn.Linear(3 * length * length, 256)) self.module_list.append(nn.Linear(256, 10)) @@ -36,9 +27,9 @@ def forward(self, x): x = self.model(x) return x -class TestConv(TestCase): +class TestConv(TestCase): def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName, + super().__init__(methodName, seed=1, ref_path=None, generate=generate) @@ -72,13 +63,13 @@ def test(self): lb_ref, ub_ref = model.compute_bounds() if linear: - assert lb.shape == ub.shape == torch.Size((N, n_classes)) + assert lb.shape == ub.shape == torch.Size((N, n_classes)) self.assertEqual(lb, lb_ref) self.assertEqual(ub, ub_ref) if not linear and layer_num == 1: pred = model(image) - lb_forward, ub_forward = model.compute_bounds(method='forward') + lb_forward, ub_forward = model.compute_bounds(method='forward') self.assertEqual(lb, lb_forward) self.assertEqual(ub, ub_forward) diff --git a/tests/test_conv1d.py b/tests/test_conv1d.py new file mode 100644 index 0000000..3ff15ea --- /dev/null +++ b/tests/test_conv1d.py @@ -0,0 +1,125 @@ +"""Test Conv1d.""" + +from collections import defaultdict +import torch +import os +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import * +from auto_LiRPA.utils import Flatten +from testcase import TestCase + + +class Model(nn.Module): + def __init__(self, kernel_size=2, stride=1, padding=0, in_features=1,out_features=1): + super(Model, self).__init__() + self.n_n_conv1d_1 = nn.Conv1d(**{'groups': 1, 'dilation': 1, 'out_channels': 1, 'padding': padding, 'kernel_size': kernel_size, 'stride': stride, 'in_channels': 1, 'bias': True}) + self.n_n_conv1d_2 = nn.Conv1d(**{'groups': 1, 'dilation': 1, 'out_channels': 1, 'padding': padding, 'kernel_size': kernel_size, 'stride': stride, 'in_channels': 1, 'bias': True}) + self.relu_2 = nn.ReLU() + self.n_n_conv1d_3 = nn.Conv1d(**{'groups': 1, 'dilation': 1, 'out_channels': 1, 'padding': padding, 'kernel_size': kernel_size, 'stride': stride, 'in_channels': 1, 'bias': True}) + self.relu_3 = nn.ReLU() + self.n_n_activation_Flatten = nn.Flatten(**{'start_dim': 1}) + L_in,dialation = in_features,1 + L_out_1 = math.floor((L_in+2*padding-dialation*(kernel_size-1)-1)/stride+1) + L_out_2 = math.floor((L_out_1+2*padding-dialation*(kernel_size-1)-1)/stride+1) + L_out_3 = math.floor((L_out_2+2*padding-dialation*(kernel_size-1)-1)/stride+1) + self.n_n_linear = nn.Linear(**{'in_features':L_out_3, 'out_features':out_features,'bias':True}) + + def forward(self, *inputs,debug=False): + t_ImageInputLayer, = inputs + t_conv1d_1 = self.n_n_conv1d_1(t_ImageInputLayer) + if debug: print("t_ImageInputLayer",t_ImageInputLayer.shape) + if debug: print("t_conv1d_1",t_conv1d_1.shape) + t_conv1d_relu_1 = F.relu(t_conv1d_1) + t_conv1d_2 = self.n_n_conv1d_2(t_conv1d_relu_1) + if debug: print("t_conv1d_2",t_conv1d_2.shape) + t_conv1d_relu_2 = F.relu(t_conv1d_2) + t_conv1d_3 = self.n_n_conv1d_3(t_conv1d_relu_2) + if debug: print("t_conv1d_3",t_conv1d_3.shape) + t_conv1d_relu_3 = F.relu(t_conv1d_3) + t_flatten = self.n_n_activation_Flatten(t_conv1d_relu_3) + if debug: print("t_flatten",t_flatten.shape) + t_linear = self.n_n_linear(t_flatten) + if debug: print("t_linear",t_linear.shape) + return t_linear + +class TestConv1D(TestCase): + def __init__(self, methodName='runTest', generate=False): + super().__init__(methodName, + seed=1, ref_path=None, + generate=generate) + + def test(self): + np.random.seed(123) + + N = 3 + C = 1 + M = 173 + n_classes = 2 + for kernel_size in [3,4]: + for padding in [0,1]: + for stride in [2,3]: + print(kernel_size, padding, stride) + + model_ori = Model(kernel_size=kernel_size, padding=padding, stride=stride, in_features=M,out_features=n_classes) + if not self.generate: + data = torch.load('data/conv1d_test_data_{}-{}-{}'.format(kernel_size, padding, stride)) + image = data['input'] + model_ori(image) + model_ori.load_state_dict(data['model']) + else: + image = torch.rand([N, C, M]) + model_ori(image) + + + conv_mode = "matrix" + + model = BoundedModule(model_ori, image, device="cpu", bound_opts={"conv_mode": conv_mode}) + eps = 0.3 + norm = np.inf + ptb = PerturbationLpNorm(norm=norm, eps=eps) + image_clean = image.detach().clone().requires_grad_(requires_grad=True) + output_clean = model_ori(image_clean) + image = BoundedTensor(image, ptb) + pred = model(image) + lb, ub,A = model.compute_bounds(return_A=True,needed_A_dict={model.output_name[0]:model.input_name[0]},) + ''' + # 1. testing if lb == ub == pred when eps = 0 + assert (lb == ub).all() and torch.allclose(lb,pred,rtol=1e-5) and torch.allclose(ub,pred,rtol=1e-5) + # 2. test if A matrix equals to gradient of the input + # get output's grad with respect to the input without iterating through torch.autograd.grad: + # https://stackoverflow.com/questions/64988010/getting-the-outputs-grad-with-respect-to-the-input + uA = A[model.output_name[0]][model.input_name[0]]['uA'] + lA = A[model.output_name[0]][model.input_name[0]]['lA'] + assert (uA==lA).all() + assert (torch.autograd.functional.jacobian(model_ori,image_clean).sum(dim=2)==uA).all() + assert (torch.autograd.functional.jacobian(model_ori,image_clean).sum(dim=2)==lA).all() + # double check + input_grads = torch.zeros(uA.shape) + for i in range(N): + for j in range(n_classes): + input_grads[i][j]=torch.autograd.grad(outputs=output_clean[i,j], inputs=image_clean, retain_graph=True)[0].sum(dim=0) + assert (input_grads==uA).all() + assert (input_grads==lA).all() + ''' + # 3. test when eps = 0.3 (uncommented) + if self.generate: + torch.save( + {'model': model_ori.state_dict(), + 'input': image, + 'lb': lb, + 'ub': ub}, 'data/conv1d_test_data_{}-{}-{}'.format(kernel_size, padding, stride) + ) + + if not self.generate: + lb_ref = data['lb'] + ub_ref = data['ub'] + assert torch.allclose(lb, lb_ref, 1e-3) + assert torch.allclose(ub, ub_ref, 1e-3) + + +if __name__ == '__main__': + testcase = TestConv1D(generate=False) + testcase.test() diff --git a/tests/test_distinct_patches.py b/tests/test_distinct_patches.py index 45f006f..e91c865 100644 --- a/tests/test_distinct_patches.py +++ b/tests/test_distinct_patches.py @@ -1,4 +1,3 @@ -from numpy.core.numeric import allclose import torch import random import numpy as np @@ -9,9 +8,9 @@ from auto_LiRPA.perturbations import * import sys sys.path.append('../examples/vision') -import models from testcase import TestCase + class cnn_4layer_b(nn.Module): def __init__(self, paddingA, paddingB): super().__init__() @@ -38,9 +37,9 @@ def forward(self, x): x = self.linear(x) return self.fc(F.relu(x)) -class TestDistinctPatches(TestCase): +class TestDistinctPatches(TestCase): def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName, + super().__init__(methodName, seed=1234, ref_path='data/resnet_patches_test_data', generate=generate) @@ -56,10 +55,10 @@ def test(self): model_ori = cnn_4layer_b(paddingA, paddingB) - conv_mode = 'patches' # conv_mode can be set as 'matrix' or 'patches' - + conv_mode = 'patches' # conv_mode can be set as 'matrix' or 'patches' + normalize = torchvision.transforms.Normalize(mean = [0.4914, 0.4822, 0.4465], std = [0.2023, 0.1994, 0.2010]) - test_data = torchvision.datasets.CIFAR10("./data", train=False, download=True, + test_data = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=torchvision.transforms.Compose([torchvision.transforms.ToTensor(), normalize])) N = 1 n_classes = 10 @@ -106,7 +105,7 @@ def test(self): assert torch.allclose(lb, lb_ref) assert torch.allclose(ub, ub_ref) - + if __name__ == '__main__': # Change to generate=True when genearting reference results diff --git a/tests/test_general_nonlinear.py b/tests/test_general_nonlinear.py new file mode 100644 index 0000000..0ab0eb6 --- /dev/null +++ b/tests/test_general_nonlinear.py @@ -0,0 +1,106 @@ +import sys +import pytest +import torch.nn as nn + +sys.path.insert(0, '../complete_verifier') + +import arguments +from beta_CROWN_solver import LiRPANet +from bab import general_bab + +from auto_LiRPA import BoundedTensor +from auto_LiRPA.perturbations import * + + +class Sin(nn.Module): + def forward(self, x): + return torch.sin(x) + + +def cifar_model_wide(): + # cifar wide + model = nn.Sequential( + nn.Conv2d(3, 16, 4, stride=2, padding=1), + Sin(), + nn.Conv2d(16, 32, 4, stride=2, padding=1), + Sin(), + nn.Flatten(), + nn.Linear(32 * 8 * 8, 100), + Sin(), + nn.Linear(100, 10) + ) + return model + + +def bab(model_ori, data, target, norm, eps, data_max=None, data_min=None): + if norm == np.inf: + if data_max is None: + data_ub = data + eps + data_lb = data - eps + else: + data_ub = torch.min(data + eps, data_max) + data_lb = torch.max(data - eps, data_min) + else: + data_ub = data_lb = data + + pred = torch.argmax(model_ori(data), dim=1) + + c = torch.zeros((1, 1, 10)) # we only support c with shape of (1, 1, n) + c[0, 0, pred] = 1 + c[0, 0, target] = -1 + + arguments.Config.parse_config() + + arguments.Config["solver"]["batch_size"] = 200 + arguments.Config["bab"]["decision_thresh"] = np.float64(10) # naive float obj has no max() function, np.inf will lead infeasible domain + arguments.Config["solver"]["beta-crown"]["iteration"] = 20 + arguments.Config["bab"]["timeout"] = 60 #300 + + arguments.Config["solver"]["alpha-crown"]["lr_alpha"] = 0.1 + arguments.Config["solver"]["beta-crown"]["lr_beta"] = 0.1 + arguments.Config["bab"]["branching"]["method"] = 'nonlinear' + arguments.Config["bab"]["branching"]["candidates"] = 2 + arguments.Config["general"]["enable_incomplete_verification"] = False + arguments.Config["data"]["dataset"] = 'cifar' + + # LiRPA wrapper + model = LiRPANet(model_ori, device='cpu', in_size=(1, 3, 32, 32), c=c) + + if list(model.net.parameters())[0].is_cuda: + data = data.cuda() + data_lb, data_ub = data_lb.cuda(), data_ub.cuda() + + ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub) + x = BoundedTensor(data, ptb).to(data_lb.device) + domain = torch.stack([data_lb.squeeze(0), data_ub.squeeze(0)], dim=-1) + forward = model_ori(x) + + min_lb = general_bab( + model, domain, x, rhs=arguments.Config["bab"]["decision_thresh"])[0] + + if isinstance(min_lb, torch.Tensor): + min_lb = min_lb.item() + + min_lb += arguments.Config["bab"]["decision_thresh"] + print(min_lb) + + assert min_lb < torch.min(forward) + +# This test takes long time so it is set as the last test case. +@pytest.mark.skip(reason="The test is failing now after removing index clamping.") +# @pytest.mark.order(-1) +def test(): + model_ori = cifar_model_wide() + data = torch.load('data/beta_crown_test_data') + model_ori.load_state_dict(data['state_dict']) + x = data['x'] + pidx = data['pidx'] + eps_temp = data['eps_temp'] + data_max = data['data_max'] + data_min = data['data_min'] + + bab(model_ori, x, pidx, float('inf'), eps_temp, data_max=data_max, data_min=data_min) + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/tests/test_identity.py b/tests/test_identity.py index dfe6355..b3fcd67 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -3,9 +3,9 @@ import torch.nn as nn from auto_LiRPA import BoundedModule, BoundedTensor from auto_LiRPA.perturbations import * -from testcase import TestCase +from testcase import TestCase -class TestIdentity(TestCase): +class TestIdentity(TestCase): def __init__(self, methodName='runTest'): super().__init__(methodName) @@ -18,10 +18,11 @@ def test(self): ptb = PerturbationLpNorm(norm=np.inf, eps=eps) x = BoundedTensor(x, ptb) y_l, y_u = model.compute_bounds() - self.assertTensorEqual(x, y) - self.assertTensorEqual(y_l, x - eps) - self.assertTensorEqual(y_u, x + eps) + self.assert_tensor_equal(x, y) + self.assert_tensor_equal(y_l, x - eps) + self.assert_tensor_equal(y_u, x + eps) + if __name__ == '__main__': testcase = TestIdentity() - testcase.test() + testcase.test() diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index d887539..e762b14 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -1,38 +1,13 @@ +# pylint: disable=wrong-import-position """Test Jacobian bounds.""" + +import sys +sys.path.append('../examples/vision') +from jacobian import compute_jacobians import torch import torch.nn as nn -import torch.nn.functional as F -from auto_LiRPA import BoundedModule, BoundedTensor -from auto_LiRPA.perturbations import * from testcase import TestCase - - -class MLP(nn.Module): - def __init__(self): - super().__init__() - self.fc1 = nn.Linear(16**2, 20) - self.fc2 = nn.Linear(20, 10) - - def forward(self, x): - x = torch.flatten(x, -1) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x - - -class CNN(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 2, 7, stride=1, padding=0) - self.conv2 = nn.Conv2d(2, 3, 7, stride=1, padding=0) - self.fc1 = nn.Linear(48, 20) - self.fc2 = nn.Linear(20, 10) - - def forward(self, x): - x = F.relu(self.conv1(x)) - x = F.relu(self.conv2(x)) - x = x.view(x.size(0), -1) - return self.fc2(F.relu(self.fc1(x))) +from auto_LiRPA.utils import Flatten class TestJacobian(TestCase): @@ -42,18 +17,20 @@ def __init__(self, methodName='runTest', generate=False): generate=generate) def test(self): - image = torch.randn(1, 1, 16, 16) - model = CNN() - model = BoundedModule(model, image, device='cpu') - ptb = PerturbationLpNorm(eps=0.1) - x = BoundedTensor(image, ptb) - output = model(x) - print(output) - model.augment_gradient_graph(x) - ret = model.compute_jacobian_bounds( - x, labels=torch.tensor([1], dtype=torch.long)) - print(ret) - self.result = [ret] + in_dim, width, linear_size = 8, 2, 8 + model = nn.Sequential( + nn.Conv2d(3, width, 3, stride=1, padding=0), + nn.ReLU(), + nn.Conv2d(width, width, 3, stride=1, padding=0), + nn.ReLU(), + Flatten(), + nn.Linear(width * (in_dim-4)**2, linear_size), + nn.ReLU(), + nn.Linear(linear_size, 10) + ) + x0 = torch.randn(1, 3, in_dim, in_dim) + self.result = compute_jacobians( + model, x0, bound_opts={'optimize_bound_args': {'iteration': 2}}) self.check() diff --git a/tests/test_language_models.py b/tests/test_language_models.py index aa4f6aa..a5ab12e 100644 --- a/tests/test_language_models.py +++ b/tests/test_language_models.py @@ -3,14 +3,14 @@ import argparse import pickle import torch -import numpy as np -import pytest from auto_LiRPA.utils import logger parser = argparse.ArgumentParser() parser.add_argument('--gen_ref', action='store_true', help='generate reference results') parser.add_argument('--train', action='store_true', help='pre-train the models') -args, unknown = parser.parse_known_args() +parser.add_argument('--keep_results', action='store_true', help='keep intermediate results.') +parser.add_argument('--load_results', action='store_true', help='load intermediate results without reruning.') +args, unknown = parser.parse_known_args() def prepare_data(): os.system('cd ../examples/language;\ @@ -59,6 +59,10 @@ def read_res(): return pickle.load(file) def evaluate(): + if args.load_results: + print("loading intermediate results...") + with open("./tmp_language_results.pkl", "rb") as file: + return pickle.load(file) logger.info('\nEvaluating the trained LSTM') print(cmd_lstm_test) print() @@ -68,8 +72,12 @@ def evaluate(): print(cmd_transformer_test) print() os.system(cmd_transformer_test) - res_transformer = read_res() + res_transformer = read_res() os.system("rm {}".format(res_path)) + if args.keep_results: + with open("./tmp_language_results.pkl", "wb") as file: + pickle.dump((res_transformer, res_lstm), file) + print("intermediate results saved.") return res_transformer, res_lstm def gen_ref(): diff --git a/tests/test_min_max.py b/tests/test_min_max.py new file mode 100644 index 0000000..21dcf6b --- /dev/null +++ b/tests/test_min_max.py @@ -0,0 +1,75 @@ +import os +from collections import defaultdict +import torch +import torch.nn as nn +import torchvision +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import PerturbationLpNorm +from auto_LiRPA.utils import * +from testcase import TestCase + +class Test_Model(nn.Module): + def __init__(self): + super(Test_Model, self).__init__() + + self.seq1 = nn.Sequential( + nn.Conv2d(1, 16, 4, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(16, 32, 4, stride=2, padding=1) + ) + + self.seq2 = nn.Sequential( + nn.Conv2d(1, 16, 4, stride=2, padding=1), + nn.ReLU(), + nn.Conv2d(16, 32, 4, stride=2, padding=1) + ) + + self.seq3 = nn.Sequential( + nn.Conv2d(32, 8, 2, stride=2, padding=1), + nn.ReLU(), + Flatten(), + nn.Linear(8*4*4,100), + nn.ReLU(), + nn.Linear(100, 10) + ) + + def forward(self, x): + return self.seq3(torch.max(self.seq1(x), self.seq2(x))) + +class TestMinMax(TestCase): + def __init__(self, methodName='runTest', generate=False): + super().__init__(methodName, + seed=1, ref_path='data/min_max_test_data', generate=generate) + + def test(self): + for conv_mode in ['patches', 'matrix']: + model = Test_Model() + checkpoint = torch.load( + os.path.join(os.path.dirname(__file__), '../examples/vision/pretrained/test_min_max.pth'), + map_location=torch.device('cpu')) + model.load_state_dict(checkpoint) + + test_data = torchvision.datasets.MNIST( + './data', train=False, download=True, + transform=torchvision.transforms.ToTensor()) + + N = 2 + image = test_data.data[:N].view(N,1,28,28) + image = image.to(torch.float32) / 255.0 + + lirpa_model = BoundedModule(model, torch.empty_like(image), device=image.device, bound_opts={"conv_mode": conv_mode}) + + eps = 0.3 + ptb = PerturbationLpNorm(eps = eps) + image = BoundedTensor(image, ptb) + + lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}}) + lb, ub = lirpa_model.compute_bounds(x=(image,), method='CROWN-Optimized') + + self.result = (lb, ub) + self.setUp() + self.check() + +if __name__ == "__main__": + testcase = TestMinMax(generate=False) + testcase.test() \ No newline at end of file diff --git a/tests/test_save_intermediate.py b/tests/test_save_intermediate.py new file mode 100644 index 0000000..571101d --- /dev/null +++ b/tests/test_save_intermediate.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import * +from testcase import TestCase + + +class test_model(nn.Module): + def __init__(self): + super(test_model, self).__init__() + self.model = nn.Sequential( + nn.Flatten(), + nn.Linear(3 * 32 * 32, 1000), + nn.Sigmoid(), + nn.Linear(1000, 500), + nn.Linear(500, 200), + nn.Linear(200, 100), + nn.ReLU(), + nn.Linear(100, 10) + ) + + def forward(self, x): + x = self.model(x) + return x + +class TestSave(TestCase): + def __init__(self, methodName='runTest'): + super().__init__(methodName) + + def test(self, gen_ref=False): + image = torch.randn(1, 3, 32, 32) + image = image.to(torch.float32) / 255.0 + model = test_model() + + bounded_model = BoundedModule( + model, image, bound_opts={ + 'optimize_bound_args': {'iteration': 2}, + }) + + ptb = PerturbationLpNorm(eps=3/255) + x = BoundedTensor(image, ptb) + bounded_model.compute_bounds(x=(x,), method='CROWN-Optimized') + save_dict = bounded_model.save_intermediate(save_path='data/test_save_data' if gen_ref else None) + ref_dict = torch.load('data/test_save_data') + + for node in ref_dict.keys(): + assert torch.allclose(ref_dict[node][0], save_dict[node][0], atol=1e-5) + assert torch.allclose(ref_dict[node][1], save_dict[node][1], atol=1e-5) + +if __name__ == '__main__': + testcase = TestSave() + testcase.test() diff --git a/tests/test_simple_verification.py b/tests/test_simple_verification.py index 8f4fb72..6b05b38 100644 --- a/tests/test_simple_verification.py +++ b/tests/test_simple_verification.py @@ -48,7 +48,7 @@ def test(self): method = 'CROWN-Optimized (alpha-CROWN)' lirpa_model.set_bound_opts({'optimize_bound_args': {'iteration': 20, 'lr_alpha': 0.1}}) _, ub = lirpa_model.compute_bounds(x=(image,), method=method.split()[0]) - self.assertTensorEqual(ub[0][7], torch.tensor(12.5080)) + self.assert_tensor_equal(ub[0][7], torch.tensor(12.5080)) if __name__ == '__main__': testcase = TestSimpleVerification() diff --git a/tests/test_upsample.py b/tests/test_upsample.py new file mode 100644 index 0000000..30a10b1 --- /dev/null +++ b/tests/test_upsample.py @@ -0,0 +1,226 @@ +from collections import defaultdict + +from torch import nn +from auto_LiRPA import BoundedModule, BoundedTensor +from auto_LiRPA.perturbations import * + +from testcase import TestCase + +class Model(nn.Module): + + def __init__(self, + input_dim=5, image_size=4, + scale_factor=2, conv_kernel_size=3, stride=1, padding=1, + conv_in_channels=16, conv_out_channels=4): + super(Model, self).__init__() + self.conv_in_channels = conv_in_channels + self.input_dim = input_dim + self.image_size = image_size + + self.fc1 = nn.Linear(input_dim, conv_in_channels * image_size * image_size) + self.upsample = nn.Upsample(scale_factor=(scale_factor, scale_factor), mode='nearest') + # H = W = 4 * scale_factor now + self.conv1 = nn.Conv2d(in_channels=conv_in_channels, out_channels=conv_out_channels, + kernel_size=(conv_kernel_size, conv_kernel_size), stride=(stride, stride), padding=padding) + # H = W = (4 * scale + 2 * pad - ker + s) // s + size_after_conv = (4 * scale_factor + 2 * padding - conv_kernel_size + stride) // stride + assert size_after_conv > 0, "0 size after convolution, please use more padding, more scale_factor," \ + "smaller kernel, or smaller stride" + self.relu = nn.ReLU() + self.flatten = nn.Flatten() + self.fc2 = nn.Linear(size_after_conv * size_after_conv * conv_out_channels, 1) + # self.sigmoid = nn.Sigmoid() + + def forward(self, input_z): + f1 = self.fc1(input_z) + d1 = f1.reshape(-1, self.conv_in_channels, self.image_size, self.image_size) + d2 = self.upsample(d1) + d3 = self.conv1(d2) + d4 = self.relu(d3) + f2 = self.flatten(d4) + f3 = self.fc2(f2) + # out = self.sigmoid(f3) + return f3 + +class ModelReducedCGAN(nn.Module): + def __init__(self): + """ + The network has the same architecture with merged bn CGAN upsampling one except reduced channel nums + """ + super(ModelReducedCGAN, self).__init__() + self.fc1 = nn.Linear(5, 32) + self.up1 = nn.Upsample(scale_factor=2, mode='nearest') + self.conv1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1) + self.relu1 = nn.ReLU() + self.up2 = nn.Upsample(scale_factor=2, mode='nearest') + self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3, stride=1, padding=1) + self.relu2 = nn.ReLU() + self.up3 = nn.Upsample(scale_factor=2, mode='nearest') + self.conv3 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=1, padding=1) + self.relu3 = nn.ReLU() + self.conv4 = nn.Conv2d(in_channels=4, out_channels=2, kernel_size=3, stride=1, padding=1) + self.conv5 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1) + self.relu4 = nn.ReLU() + self.conv6 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1) + self.relu5 = nn.ReLU() + self.conv7 = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1) + self.relu6 = nn.ReLU() + self.conv8 = nn.Conv2d(in_channels=4, out_channels=4, kernel_size=3, stride=2, padding=1) + self.relu7 = nn.ReLU() + self.fc2 = nn.Linear(4 * 2 * 2, 1) + self.sigmoid = nn.Sigmoid() + + def forward(self, input_z): + f1 = self.fc1(input_z) + f2 = f1.reshape(-1, 2, 4, 4) + f3 = self.up1(f2) + f4 = self.conv1(f3) + f5 = self.relu1(f4) + f6 = self.up2(f5) + f7 = self.conv2(f6) + f8 = self.relu2(f7) + f9 = self.up3(f8) + f10 = self.conv3(f9) + f11 = self.relu3(f10) + f12 = self.conv4(f11) + f13 = self.conv5(f12) + f14 = self.relu4(f13) + f15 = self.conv6(f14) + f16 = self.relu5(f15) + f17 = self.conv7(f16) + f18 = self.relu6(f17) + f19 = self.conv8(f18) + f20 = self.relu7(f19) + f21 = f20.reshape(f20.shape[0], -1) + f22 = self.fc2(f21) + # f23 = self.sigmoid(f22) + return f22 + + + +def recursive_allclose(a, b: dict, verbose=False, prefix=''): + """ + Recursively check whether every corresponding tensors in two dicts are close + :param a: dict a + :param b: dict b + :param prefix: reserved for path tracking in recursive calling for error printing + :return: bool: all_close or not + """ + tot_tensor = 0 + tot_dict = 0 + for k in a: + if isinstance(a[k], torch.Tensor): + if k == 'unstable_idx': continue + if verbose: + print(f'recursive_allclose(): Checking {prefix}{k}') + assert k in b and isinstance(b[k], torch.Tensor) or isinstance(b[k], Patches), f'recursive_allclose(): Tensor not found in path {prefix}{k}' + if isinstance(b[k], torch.Tensor): + assert torch.allclose(a[k].reshape(-1), b[k].reshape(-1), 1e-4, 1e-5), f'recursive_allclose(): Inconsistency found in path {prefix}{k}' + tot_tensor += 1 + elif isinstance(a[k], dict): + assert k in b and isinstance(b[k], dict), f'recursive_allclose(): dict not found in path {prefix}{k}' + recursive_allclose(a[k], b[k], verbose, prefix + k) + tot_dict += 1 + tot_b_tensor = sum([1 if isinstance(v, torch.Tensor) or isinstance(v, Patches) and k != 'unstable_idx' else 0 for k, v in b.items()]) + tot_b_dict = sum([1 if isinstance(v, dict) else 0 for v in b.values()]) + assert tot_tensor == tot_b_tensor, f'recursive_allclose(): Extra tensors found in path {prefix}' + assert tot_dict == tot_b_dict, f'recursive_allclose(): Extra recursive paths found in path {prefix}' + return True + + +class TestUpSample(TestCase): + def __init__(self, methodName='runTest', generate=False, device='cpu'): + super().__init__(methodName, seed=1, ref_path=None, generate=generate) + self.device = device + + def test(self, seed=123): + for kernel_size in [3,5]: + for scaling_factor in [2,3,4]: + for stride in [1,2]: + for padding in [1]: + self.test_instance(kernel_size, scaling_factor, stride, padding, seed=seed) + + def test_instance(self, kernel_size=3, scaling_factor=2, stride=1, padding=1, seed=123): + self.set_seed(seed) + + print(f'kernel_size = {kernel_size}, scaling_factor = {scaling_factor}, stride = {stride}, padding = {padding}') + random_input = torch.randn((1,5)).to(torch.device(self.device)) * 1000. + eps = 0.3 + + model_ori = Model(scale_factor=scaling_factor, + conv_kernel_size=kernel_size, + stride=stride, + padding=padding) + + ptb = PerturbationLpNorm(norm=np.inf, eps=eps) + z1_clean = random_input.detach().clone().requires_grad_(requires_grad=True) + + z1 = BoundedTensor(random_input, ptb) + model_mat = BoundedModule(model_ori, (random_input,), device=self.device, + bound_opts={"conv_mode": "matrix"}) + pred_of_mat = model_mat(z1) + lb_m, ub_m, A_m = model_mat.compute_bounds(return_A=True, needed_A_dict={model_mat.output_name[0]: model_mat.input_name[0]}, ) + + model_pat = BoundedModule(model_ori, (random_input,), device=self.device, + bound_opts={"conv_mode": "patches"}) + pred_of_patch = model_pat(z1) + lb_p, ub_p, A_p = model_pat.compute_bounds(return_A=True, needed_A_dict={ + model_pat.output_name[0]: model_pat.input_name[0]}, ) + + assert torch.allclose(pred_of_mat, pred_of_patch, 1e-5) + assert torch.allclose(lb_m, lb_p, 1e-5) + assert torch.allclose(ub_m, ub_p, 1e-5) + assert recursive_allclose(A_m, A_p, verbose=True) + +class TestReducedCGAN(TestCase): + + def __init__(self, methodName='runTest', generate=False, device='cpu'): + super().__init__(methodName, seed=1, ref_path=None, generate=generate) + self.device = device + + def test(self, seed=456): + self.set_seed(seed) + input = torch.tensor([[0.583, -0.97, -0.97, 0.598, 0.737]]).to(torch.device(self.device)) + eps = 0.1 + + model_ori = ModelReducedCGAN() + + ptb = PerturbationLpNorm(norm=np.inf, eps=eps) + z1_clean = input.detach().clone().requires_grad_(requires_grad=True) + + z1 = BoundedTensor(input, ptb) + model_mat = BoundedModule(model_ori, (input,), device=self.device, + bound_opts={"conv_mode": "matrix"}) + pred_of_mat = model_mat(z1) + + needed_A_dict = defaultdict(set) + for node in model_mat.nodes(): + needed_A_dict[node.name] = set() + + lb_m, ub_m, A_m = model_mat.compute_bounds((z1,), return_A=True, needed_A_dict=needed_A_dict, method='crown') + + model_pat = BoundedModule(model_ori, (input,), device=self.device, + bound_opts={"conv_mode": "patches", "sparse_features_alpha": False}) + pred_of_patch = model_pat(z1) + lb_p, ub_p, A_p = model_pat.compute_bounds((z1,), return_A=True, needed_A_dict=needed_A_dict, method='crown') + + # print(pred_of_mat, pred_of_patch) + assert torch.allclose(pred_of_mat, pred_of_patch, 1e-5) + assert torch.allclose(lb_m, lb_p, 1e-5) + assert torch.allclose(ub_m, ub_p, 1e-5) + assert recursive_allclose(A_m, A_p, verbose=True) + +if __name__ == '__main__': + # should use device = 'cpu' for GitHub CI + testcase = TestUpSample(generate=False, device='cpu') + testcase.test(seed=123) + + # """ + # following test is much stronger, but runs within 30s only on GPUs + # so commented it out for CI testing now + # required GPU memory: 1.5 GiB + # """ + # testhardcase = TestReducedCGAN(generate=False, device='cuda') + # testhardcase.test(seed=456) + + diff --git a/tests/test_vision_models.py b/tests/test_vision_models.py index 7a5019a..c3e8f6a 100644 --- a/tests/test_vision_models.py +++ b/tests/test_vision_models.py @@ -26,7 +26,7 @@ def forward(self, x): class TestVisionModels(TestCase): def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName, seed=1234, + super().__init__(methodName, seed=1234, ref_path='data/vision_test_data', generate=generate) self.result = {} @@ -40,8 +40,8 @@ def verify_bounds(self, model, x, IBP, method, forward_ret, lb_name, ub_name): grad = x.grad self.result[lb_name[:-2] + 'grad'] = grad.clone() if not self.generate: - assert torch.allclose(lb, self.reference[lb_name], 1e-4), (lb - self.reference[lb_name]).abs().sum() - assert torch.allclose(ub, self.reference[ub_name], 1e-4), (ub - self.reference[ub_name]).abs().sum() + assert torch.allclose(lb, self.reference[lb_name], 1e-4, atol=2e-7), (lb - self.reference[lb_name]).abs().max() + assert torch.allclose(ub, self.reference[ub_name], 1e-4, atol=2e-7), (ub - self.reference[ub_name]).abs().max() assert ((lb - self.reference[lb_name]).pow(2).sum() < 1e-9), (lb - self.reference[lb_name]).pow(2).sum() assert ((ub - self.reference[ub_name]).pow(2).sum() < 1e-9), (ub - self.reference[ub_name]).pow(2).sum() assert torch.allclose(grad, self.reference[lb_name[:-2] + 'grad'], 1e-4, 1e-6) diff --git a/tests/test_weight_perturbation.py b/tests/test_weight_perturbation.py index 3cc7ba0..a07f29f 100644 --- a/tests/test_weight_perturbation.py +++ b/tests/test_weight_perturbation.py @@ -11,7 +11,9 @@ class TestWeightPerturbation(TestCase): def __init__(self, methodName='runTest', generate=False): - super().__init__(methodName, seed=1234, ref_path='data/weight_perturbation_test_data') + super().__init__( + methodName, seed=1234, + ref_path='data/weight_perturbation_test_data', generate=generate) self.result = {} def test_training(self): @@ -32,20 +34,22 @@ def verify_bounds(self, model, x, IBP, method, forward_ret, lb_name, ub_name): self.result[lb_name] = lb.detach().data.clone() self.result[ub_name] = ub.detach().data.clone() - assert torch.allclose(self.reference[lb_name], self.result[lb_name], 1e-4, 1e-6) - assert torch.allclose(self.reference[ub_name], self.result[ub_name], 1e-4, 1e-6) - assert ((self.reference[lb_name] - self.result[lb_name]).pow(2).sum() < 1e-8) - assert ((self.reference[ub_name] - self.result[ub_name]).pow(2).sum() < 1e-8) - # test gradient backward propagation loss = (ub - lb).abs().sum() loss.backward() - # gradient w.r.t input only grad = x.grad self.result[lb_name+'_grad'] = grad.detach().data.clone() - assert torch.allclose(self.reference[lb_name+'_grad'], self.result[lb_name + '_grad'], 1e-4, 1e-6) - assert ((self.reference[lb_name + '_grad'] - self.result[lb_name + '_grad']).pow(2).sum() < 1e-8) + + if not self.generate: + assert torch.allclose(self.reference[lb_name], self.result[lb_name], 1e-4, 1e-6) + assert torch.allclose(self.reference[ub_name], self.result[ub_name], 1e-4, 1e-6) + assert ((self.reference[lb_name] - self.result[lb_name]).pow(2).sum() < 1e-8) + assert ((self.reference[ub_name] - self.result[ub_name]).pow(2).sum() < 1e-8) + assert torch.allclose(self.reference[lb_name+'_grad'], + self.result[lb_name + '_grad'], 1e-4, 1e-6) + assert ((self.reference[lb_name + '_grad'] + - self.result[lb_name + '_grad']).pow(2).sum() < 1e-8) def test_perturbation(self): np.random.seed(123) # FIXME This seed is inconsistent with other seeds (1234) @@ -91,7 +95,7 @@ def verify_model(pert_weight=True, pert_bias=True, norm=np.inf, lb_name='', ub_n self.save() if __name__ == '__main__': - testcase = TestWeightPerturbation() + testcase = TestWeightPerturbation(generate=False) testcase.setUp() testcase.test_perturbation() testcase.test_training() diff --git a/tests/testcase.py b/tests/testcase.py index b4dff42..9d5b8a0 100644 --- a/tests/testcase.py +++ b/tests/testcase.py @@ -9,8 +9,8 @@ class TestCase(unittest.TestCase): def __init__(self, methodName='runTest', seed=1, ref_path=None, generate=False): super().__init__(methodName) - self.addTypeEqualityFunc(np.ndarray, 'assertArrayEqual') - self.addTypeEqualityFunc(torch.Tensor, 'assertTensorEqual') + self.addTypeEqualityFunc(np.ndarray, 'assert_array_equal') + self.addTypeEqualityFunc(torch.Tensor, 'assert_tensor_equal') self.set_seed(seed) self.ref_path = ref_path @@ -46,15 +46,25 @@ def check(self): if self.generate: self.save() else: - for i in range(len(self.result)): - self.assertEqual(self.result[i], self.reference[i]) + self.assert_equal(self.result, self.reference) - def assertArrayEqual(self, a, b, msg=None): - self.assertIsInstance(a, np.ndarray, 'First argument is not an np.ndarray') - self.assertIsInstance(b, np.ndarray, 'Second argument is not an np.ndarray') + def assert_equal(self, a, b): + assert type(a) == type(b) + if isinstance(a, list): + for a_, b_ in zip(a, b): + self.assert_equal(a_, b_) + elif isinstance(a, tuple): + for a_, b_ in zip(a, b): + self.assert_equal(a_, b_) + elif isinstance(a, np.ndarray): + self.assert_array_equal(a, b) + elif isinstance(a, torch.Tensor): + self.assert_tensor_equal(a, b) + else: + assert a == b + + def assert_array_equal(self, a, b, msg=None): return np.allclose(a, b) - def assertTensorEqual(self, a, b, msg=None): - self.assertIsInstance(a, torch.Tensor, 'First argument is not an torch.Tensor') - self.assertIsInstance(b, torch.Tensor, 'Second argument is not an torch.Tensor') + def assert_tensor_equal(self, a, b, msg=None): return torch.allclose(a, b)