Explainable Artificial Intelligence (XAI) for Neutral Networks in tensorflow/keras.
With sauron
you can use Explainable Artificial Intelligence (XAI)
methods to understand predictions made by Neural Networks in
tensorflow/keras
. For the time being only Convolutional Neural
Networks are supported, but it will change in time.
You can install the latest version of sauron
with remotes
:
remotes::install_github("maju116/sauron")
(main
branch contains the stable version. Use develop
branch for
latest features)
To install previous versions you can run:
remotes::install_github("maju116/sauron", ref = "0.1.0")
In order to install sauron
you need to install keras
and
tensorflow
packages and Tensorflow
version >= 2.0.0
(Tensorflow 1.x
will not be supported!)
To generate any explanations you will have to create an object of class
CNNexplainer
. To do this you will need two things:
- tensorflow/keras model
- image preprocessing function (optional)
library(tidyverse)
library(sauron)
model <- application_xception()
preprocessing_function <- xception_preprocess_input
explainer <- CNNexplainer$new(model = model,
preprocessing_function = preprocessing_function,
id = "imagenet_xception")
explainer
#> <CNNexplainer>
#> Public:
#> clone: function (deep = FALSE)
#> explain: function (input_imgs_paths, class_index = NULL, methods = c("V",
#> id: imagenet_xception
#> initialize: function (model, preprocessing_function, id = NULL)
#> model: function (object, ...)
#> preprocessing_function: function (x)
#> show_available_methods: function ()
#> Private:
#> available_methods: tbl_df, tbl, data.frame
To see available XAI methods for the CNNexplainer
object use:
explainer$show_available_methods()
#> # A tibble: 8 x 2
#> method name
#> <chr> <chr>
#> 1 V Vanilla gradient
#> 2 GI Gradient x Input
#> 3 SG SmoothGrad
#> 4 SGI SmoothGrad x Input
#> 5 IG Integrated Gradients
#> 6 GB Guided Backpropagation
#> 7 OCC Occlusion Sensitivity
#> 8 GGC Guided Grad-CAM
Now you can explain predictions using explain
method. You will need:
- paths to the images for which you want to generate explanations.
- class indexes for which the explanations should be generated
(optional, if set to
NULL
class that maximizes predicted probability will be found for each image). - character vector with method names (optional, by default explainer will use all methods).
- batch size (optional, by default number of inserted images).
- additional arguments with settings for a specific method (optional).
As an output you will get an object of class CNNexplanations
:
input_imgs_paths <- list.files(system.file("extdata", "images", package = "sauron"), full.names = TRUE)
explanations <- explainer$explain(input_imgs_paths = input_imgs_paths,
class_index = NULL,
batch_size = 1,
methods = c("V", "IG", "GB", "GGC"),
steps = 10, # Number of Integrated Gradients steps
grayscale = FALSE # RGB or Gray gradients
)
explanations
#> CNNexplanations object contains explanations for 3 images for 1 model.
You can get raw explanations and metadata from CNNexplanations
object
using:
explanations$get_metadata()
#> $multimodel_explanations
#> [1] FALSE
#>
#> $ids
#> [1] "imagenet_xception"
#>
#> $n_models
#> [1] 1
#>
#> $target_sizes
#> $target_sizes[[1]]
#> [1] 299 299 3
#>
#>
#> $methods
#> [1] "V" "IG" "GB" "GGC"
#>
#> $input_imgs_paths
#> [1] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/cat_and_dog.jpg"
#> [2] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/cat.jpeg"
#> [3] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/zebras.jpg"
#>
#> $n_imgs
#> [1] 3
raw_explanations <- explanations$get_explanations()
str(raw_explanations)
#> List of 1
#> $ imagenet_xception:List of 5
#> ..$ Input: num [1:3, 1:299, 1:299, 1:3] 147 134 170 147 134 168 144 134 170 144 ...
#> .. ..- attr(*, "dimnames")=List of 4
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> ..$ V : int [1:3, 1:299, 1:299, 1:3] 0 0 0 0 0 0 0 0 0 0 ...
#> .. ..- attr(*, "dimnames")=List of 4
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> ..$ IG : int [1:3, 1:299, 1:299, 1:3] 0 0 0 0 0 0 0 0 0 0 ...
#> .. ..- attr(*, "dimnames")=List of 4
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> ..$ GB : int [1:3, 1:299, 1:299, 1:3] 0 0 2 0 0 111 0 0 28 0 ...
#> .. ..- attr(*, "dimnames")=List of 4
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> ..$ GGC : num [1:3, 1:299, 1:299, 1] 7.13e-05 0.00 4.55e-04 7.13e-05 0.00 ...
#> .. ..- attr(*, "dimnames")=List of 4
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
#> .. .. ..$ : NULL
To visualize and save generated explanations use:
explanations$plot_and_save(combine_plots = TRUE, # Show all explanations side by side on one image?
output_path = NULL, # Where to save output(s)
plot = TRUE # Should output be plotted?
)
If you want to compare two or more different models you can do it by
combining CNNexplainer
objects into CNNexplainers
object:
model2 <- application_densenet121()
preprocessing_function2 <- densenet_preprocess_input
explainer2 <- CNNexplainer$new(model = model2,
preprocessing_function = preprocessing_function2,
id = "imagenet_densenet121")
model3 <- application_densenet201()
preprocessing_function3 <- densenet_preprocess_input
explainer3 <- CNNexplainer$new(model = model3,
preprocessing_function = preprocessing_function3,
id = "imagenet_densenet201")
explainers <- CNNexplainers$new(explainer, explainer2, explainer3)
explanations123 <- explainers$explain(input_imgs_paths = input_imgs_paths,
class_index = NULL,
batch_size = 1,
methods = c("V", "IG", "GB", "GGC"),
steps = 10,
grayscale = FALSE
)
explanations123$get_metadata()
#> $multimodel_explanations
#> [1] TRUE
#>
#> $ids
#> [1] "imagenet_xception" "imagenet_densenet121" "imagenet_densenet201"
#>
#> $n_models
#> [1] 3
#>
#> $target_sizes
#> $target_sizes[[1]]
#> [1] 299 299 3
#>
#> $target_sizes[[2]]
#> [1] 224 224 3
#>
#> $target_sizes[[3]]
#> [1] 224 224 3
#>
#>
#> $methods
#> [1] "V" "IG" "GB" "GGC"
#>
#> $input_imgs_paths
#> [1] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/cat_and_dog.jpg"
#> [2] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/cat.jpeg"
#> [3] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/zebras.jpg"
#>
#> $n_imgs
#> [1] 3
explanations123$plot_and_save(combine_plots = TRUE,
output_path = NULL,
plot = TRUE
)
Alternatively if you already have some CNNexplanations
objects
generated (for the same images and using same methods) you can combine
them:
explanations2 <- explainer2$explain(input_imgs_paths = input_imgs_paths,
class_index = NULL,
batch_size = 1,
methods = c("V", "IG", "GB", "GGC"),
steps = 10,
grayscale = FALSE
)
explanations3 <- explainer3$explain(input_imgs_paths = input_imgs_paths,
class_index = NULL,
batch_size = 1,
methods = c("V", "IG", "GB", "GGC"),
steps = 10,
grayscale = FALSE
)
explanations$combine(explanations2, explanations3)
explanations$get_metadata()
#> $multimodel_explanations
#> [1] TRUE
#>
#> $ids
#> [1] "imagenet_xception" "imagenet_densenet121" "imagenet_densenet201"
#>
#> $n_models
#> [1] 3
#>
#> $target_sizes
#> $target_sizes[[1]]
#> [1] 299 299 3
#>
#> $target_sizes[[2]]
#> [1] 224 224 3
#>
#> $target_sizes[[3]]
#> [1] 224 224 3
#>
#>
#> $methods
#> [1] "V" "IG" "GB" "GGC"
#>
#> $input_imgs_paths
#> [1] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/cat_and_dog.jpg"
#> [2] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/cat.jpeg"
#> [3] "/home/maju116/R/x86_64-pc-linux-gnu-library/4.0/sauron/extdata/images/zebras.jpg"
#>
#> $n_imgs
#> [1] 3
explanations$plot_and_save(combine_plots = TRUE,
output_path = NULL,
plot = TRUE
)