mnmueller/jax_verify

 
 

jax_verify: Neural Network Verification in JAX


jax_verify is a library of JAX implementations of many widely used neural network verification techniques.

Overview

To get started verifying your neural networks with jax_verify, the main thing to know is that it provides a simple, consistent interface to a variety of verification algorithms:

output_bounds = jax_verify.verification_technique(network_fn, input_bounds)

Here, network_fn is any JAX function, input_bounds defines bounds over the possible inputs to network_fn, and output_bounds holds the computed bounds over the possible outputs of network_fn. verification_technique can be any of the algorithms implemented in jax_verify, such as interval_bound_propagation or crown_bound_propagation.
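To make "bounds over possible outputs" concrete, here is a minimal sketch of what interval bound propagation computes for one affine layer followed by a ReLU. It is written in plain Python rather than JAX so that it runs without dependencies, and it is not how jax_verify implements the technique. The helper names affine_bounds and relu_bounds, the weights, and the input box are all made up for illustration.

```python
def affine_bounds(lower, upper, weights, bias):
    """Propagate elementwise input intervals through y = W x + b."""
    out_lower, out_upper = [], []
    for row, b in zip(weights, bias):
        lo = hi = b
        for w, l, u in zip(row, lower, upper):
            # A non-negative weight maps the lower input endpoint to the
            # lower output endpoint; a negative weight swaps the endpoints.
            lo += w * l if w >= 0 else w * u
            hi += w * u if w >= 0 else w * l
        out_lower.append(lo)
        out_upper.append(hi)
    return out_lower, out_upper


def relu_bounds(lower, upper):
    """ReLU is monotone, so it can be applied to each endpoint directly."""
    return [max(l, 0.0) for l in lower], [max(u, 0.0) for u in upper]


# Hypothetical 2-input, 2-output layer and an input box of radius 0.5
# around the point (1, -1).
weights = [[1.0, -2.0], [0.5, 1.0]]
bias = [0.0, -1.0]
in_lower, in_upper = [0.5, -1.5], [1.5, -0.5]

pre_lower, pre_upper = affine_bounds(in_lower, in_upper, weights, bias)
out_lower, out_upper = relu_bounds(pre_lower, pre_upper)
print(out_lower, out_upper)  # [1.5, 0.0] [4.5, 0.0]
```

In this toy case the second output is bounded by 0 from both sides, so the bounds alone prove that it is zero for every input in the box. That is the kind of statement output_bounds is meant to support.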

The overall approach is to use JAX’s powerful program transformation system, which allows us to analyze general network structures defined by network_fn and then to define corresponding functions for calculating verified bounds for these networks.
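The idea of analyzing a function's structure and then deriving a bound-computing function from it can be sketched in a few lines. The toy below records the operations a scalar function performs, then replays them with interval rules. It is only an analogy in plain Python: the names Traced, trace, INTERVAL_RULES and interval_transform are hypothetical and are not part of JAX or jax_verify, and it handles only scalar chains with constants on the right-hand side.

```python
class Traced:
    """Placeholder value that records every operation applied to it."""

    def __init__(self, tape, index):
        self.tape, self.index = tape, index

    def _record(self, op, const):
        # Each tape entry is (operation, index of its input value, constant).
        self.tape.append((op, self.index, const))
        return Traced(self.tape, len(self.tape))

    def __add__(self, const):
        return self._record("add", const)

    def __mul__(self, const):
        return self._record("mul", const)


def relu(x):
    return x._record("relu", None)


def trace(fn):
    """Run fn on a placeholder; return the recorded ops and the output index."""
    tape = []
    out = fn(Traced(tape, 0))
    return tape, out.index


# One interval rule per recorded operation.
INTERVAL_RULES = {
    "add": lambda lo, hi, c: (lo + c, hi + c),
    "mul": lambda lo, hi, c: (lo * c, hi * c) if c >= 0 else (hi * c, lo * c),
    "relu": lambda lo, hi, _: (max(lo, 0.0), max(hi, 0.0)),
}


def interval_transform(fn):
    """Turn a function on numbers into a function on (lower, upper) pairs."""
    tape, out_index = trace(fn)

    def bound_fn(lower, upper):
        values = [(lower, upper)]  # values[0] is the input interval
        for op, src, const in tape:
            values.append(INTERVAL_RULES[op](*values[src], const))
        return values[out_index]

    return bound_fn


def network_fn(x):
    return relu(x * -2.0 + 1.0)


bound_fn = interval_transform(network_fn)
print(bound_fn(-1.0, 1.0))  # (0.0, 3.0)
```

Note that network_fn is written once, as ordinary code, and the bound computation is derived from it rather than written by hand for each architecture.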

Verification Techniques

The methods currently provided by jax_verify include interval_bound_propagation and crown_bound_propagation, among others.

Installation

Stable: Run pip install jax_verify; you can then import jax_verify from your Python code.

Latest: Clone this repository and run pip install . from the repository root.

Getting Started

We suggest starting by looking at the minimal examples in the examples/ directory. For example, all the bound propagation techniques can be run with the run_boundprop.py script:

cd examples/
python3 run_boundprop.py --boundprop_method=interval_bound_propagation

For documentation, please refer to the API reference page.

Notes

Contributions of additional verification techniques are very welcome. Please open an issue first to let us know.

This is not an official Google product.
