From 86dc77f40ce4ff114d42c03464d68270fda823b9 Mon Sep 17 00:00:00 2001 From: giandos200 Date: Thu, 13 Jan 2022 12:38:39 +0100 Subject: [PATCH] Enhancement in CF Generation --- dice_ml/__init__.py | 2 +- dice_ml/counterfactual_explanations.py | 9 ++- .../data_interfaces/private_data_interface.py | 7 +- .../data_interfaces/public_data_interface.py | 10 ++- dice_ml/dice.py | 10 ++- dice_ml/diverse_counterfactuals.py | 7 +- dice_ml/explainer_interfaces/dice_KD.py | 13 ++-- dice_ml/explainer_interfaces/dice_genetic.py | 38 ++++++----- dice_ml/explainer_interfaces/dice_pytorch.py | 10 +-- dice_ml/explainer_interfaces/dice_random.py | 17 +++-- .../explainer_interfaces/dice_tensorflow1.py | 12 ++-- .../explainer_interfaces/dice_tensorflow2.py | 12 ++-- .../explainer_interfaces/explainer_base.py | 67 +++++++------------ .../explainer_interfaces/feasible_base_vae.py | 13 ++-- .../feasible_model_approx.py | 8 +-- dice_ml/model.py | 4 +- dice_ml/model_interfaces/base_model.py | 8 +-- .../keras_tensorflow_model.py | 5 +- dice_ml/model_interfaces/pytorch_model.py | 5 +- dice_ml/utils/helpers.py | 16 ++--- .../utils/sample_architecture/vae_model.py | 4 +- ...ing_different_CF_explanation_methods.ipynb | 4 +- 22 files changed, 132 insertions(+), 149 deletions(-) diff --git a/dice_ml/__init__.py b/dice_ml/__init__.py index 63a3dc1d..cf41c603 100644 --- a/dice_ml/__init__.py +++ b/dice_ml/__init__.py @@ -1,6 +1,6 @@ from .data import Data -from .dice import Dice from .model import Model +from .dice import Dice __all__ = ["Data", "Model", diff --git a/dice_ml/counterfactual_explanations.py b/dice_ml/counterfactual_explanations.py index ba1a95f3..2e499784 100644 --- a/dice_ml/counterfactual_explanations.py +++ b/dice_ml/counterfactual_explanations.py @@ -1,12 +1,11 @@ import json -import os - import jsonschema +import os -from dice_ml.constants import _SchemaVersions -from dice_ml.diverse_counterfactuals import (CounterfactualExamples, - _DiverseCFV2SchemaConstants) +from dice_ml.diverse_counterfactuals import CounterfactualExamples from dice_ml.utils.exception import UserConfigValidationException +from dice_ml.diverse_counterfactuals import _DiverseCFV2SchemaConstants +from dice_ml.constants import _SchemaVersions class _CommonSchemaConstants: diff --git a/dice_ml/data_interfaces/private_data_interface.py b/dice_ml/data_interfaces/private_data_interface.py index 2960c17b..f319442d 100644 --- a/dice_ml/data_interfaces/private_data_interface.py +++ b/dice_ml/data_interfaces/private_data_interface.py @@ -1,11 +1,10 @@ """Module containing meta data information about private data.""" -import collections -import logging import sys - -import numpy as np import pandas as pd +import numpy as np +import collections +import logging from dice_ml.data_interfaces.base_data_interface import _BaseData diff --git a/dice_ml/data_interfaces/public_data_interface.py b/dice_ml/data_interfaces/public_data_interface.py index ab3e5ed1..31b8e6af 100644 --- a/dice_ml/data_interfaces/public_data_interface.py +++ b/dice_ml/data_interfaces/public_data_interface.py @@ -1,15 +1,13 @@ """Module containing all required information about the interface between raw (or transformed) public data and DiCE explainers.""" +import pandas as pd +import numpy as np import logging from collections import defaultdict -import numpy as np -import pandas as pd - from dice_ml.data_interfaces.base_data_interface import _BaseData -from dice_ml.utils.exception import (SystemException, - UserConfigValidationException) +from dice_ml.utils.exception import 
SystemException, UserConfigValidationException class PublicData(_BaseData): @@ -260,7 +258,7 @@ def get_valid_feature_range(self, feature_range_input, normalized=True): """ feature_range = {} - for _, feature_name in enumerate(self.feature_names): + for idx, feature_name in enumerate(self.feature_names): feature_range[feature_name] = [] if feature_name in self.continuous_feature_names: max_value = self.data_df[feature_name].max() diff --git a/dice_ml/dice.py b/dice_ml/dice.py index d1c78172..961b2859 100644 --- a/dice_ml/dice.py +++ b/dice_ml/dice.py @@ -3,9 +3,9 @@ such as RandomSampling, DiCEKD or DiCEGenetic""" from dice_ml.constants import BackEndTypes, SamplingStrategy -from dice_ml.data_interfaces.private_data_interface import PrivateData -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase from dice_ml.utils.exception import UserConfigValidationException +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml.data_interfaces.private_data_interface import PrivateData class Dice(ExplainerBase): @@ -67,14 +67,12 @@ def decide(model_interface, method): elif model_interface.backend == BackEndTypes.Tensorflow1: # pretrained Keras Sequential model with Tensorflow 1.x backend - from dice_ml.explainer_interfaces.dice_tensorflow1 import \ - DiceTensorFlow1 + from dice_ml.explainer_interfaces.dice_tensorflow1 import DiceTensorFlow1 return DiceTensorFlow1 elif model_interface.backend == BackEndTypes.Tensorflow2: # pretrained Keras Sequential model with Tensorflow 2.x backend - from dice_ml.explainer_interfaces.dice_tensorflow2 import \ - DiceTensorFlow2 + from dice_ml.explainer_interfaces.dice_tensorflow2 import DiceTensorFlow2 return DiceTensorFlow2 elif model_interface.backend == BackEndTypes.Pytorch: diff --git a/dice_ml/diverse_counterfactuals.py b/dice_ml/diverse_counterfactuals.py index 2dc5c044..fe8aa134 100644 --- a/dice_ml/diverse_counterfactuals.py +++ b/dice_ml/diverse_counterfactuals.py @@ -1,10 +1,8 @@ +import pandas as pd import copy import json - -import pandas as pd - -from dice_ml.constants import ModelTypes, _SchemaVersions from dice_ml.utils.serialize import DummyDataInterface +from dice_ml.constants import _SchemaVersions, ModelTypes class _DiverseCFV1SchemaConstants: @@ -117,7 +115,6 @@ def _visualize_internal(self, display_sparse_df=True, show_only_changes=False, def visualize_as_dataframe(self, display_sparse_df=True, show_only_changes=False): from IPython.display import display - # original instance print('Query instance (original outcome : %i)' % round(self.test_pred)) display(self.test_instance_df) # works only in Jupyter notebook diff --git a/dice_ml/explainer_interfaces/dice_KD.py b/dice_ml/explainer_interfaces/dice_KD.py index 532c340e..bac3bf6e 100644 --- a/dice_ml/explainer_interfaces/dice_KD.py +++ b/dice_ml/explainer_interfaces/dice_KD.py @@ -2,15 +2,14 @@ Module to generate counterfactual explanations from a KD-Tree This code is similar to 'Interpretable Counterfactual Explanations Guided by Prototypes': https://arxiv.org/pdf/1907.02584.pdf """ -import copy -import timeit - +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase import numpy as np +import timeit import pandas as pd +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceKD(ExplainerBase): @@ -260,10 +259,14 @@ def find_counterfactuals(self, data_df_copy, query_instance, query_instance_orig if 
total_cfs_found < total_CFs: self.elapsed = timeit.default_timer() - start_time m, s = divmod(self.elapsed, 60) - print('Only %d (required %d) ' % (total_cfs_found, self.total_CFs), + print('Only %d (required %d) ' % (total_cfs_found, total_CFs), 'Diverse Counterfactuals found for the given configuation, perhaps ', 'change the query instance or the features to vary...' '; total time taken: %02d' % m, 'min %02d' % s, 'sec') + elif total_cfs_found == 0: + print( + 'No Counterfactuals found for the given configuration, perhaps try with different parameters...', + '; total time taken: %02d' % m, 'min %02d' % s, 'sec') else: print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec') diff --git a/dice_ml/explainer_interfaces/dice_genetic.py b/dice_ml/explainer_interfaces/dice_genetic.py index ed35ee49..5928bbc2 100644 --- a/dice_ml/explainer_interfaces/dice_genetic.py +++ b/dice_ml/explainer_interfaces/dice_genetic.py @@ -2,17 +2,16 @@ Module to generate diverse counterfactual explanations based on genetic algorithm This code is similar to 'GeCo: Quality Counterfactual Explanations in Real Time': https://arxiv.org/pdf/2101.01292.pdf """ -import copy -import random -import timeit - +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase import numpy as np import pandas as pd +import random +import timeit +import copy from sklearn.preprocessing import LabelEncoder from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceGenetic(ExplainerBase): @@ -116,9 +115,8 @@ def do_random_init(self, num_inits, features_to_vary, query_instance, desired_cl def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desired_range): cfs = self.label_encode(cfs) cfs = cfs.reset_index(drop=True) - - self.cfs = np.zeros((self.population_size, self.data_interface.number_of_features)) - for kx in range(self.population_size): + row = [] + for kx in range(self.population_size * 5): if kx >= len(cfs): break one_init = np.zeros(self.data_interface.number_of_features) @@ -143,16 +141,21 @@ def do_KD_init(self, features_to_vary, query_instance, cfs, desired_class, desir one_init[jx] = query_instance[jx] else: one_init[jx] = np.random.choice(self.feature_range[feature]) - self.cfs[kx] = one_init + t = tuple(one_init) + if t not in row: + row.append(t) + if len(row) == self.population_size: + break kx += 1 + self.cfs = np.array(row) - new_array = [tuple(row) for row in self.cfs] - uniques = np.unique(new_array, axis=0) - - if len(uniques) != self.population_size: + if len(self.cfs) != self.population_size: + print("Warning: not enough unique KD-tree initializations were found; filling the rest with random initializations.") remaining_cfs = self.do_random_init( - self.population_size - len(uniques), features_to_vary, query_instance, desired_class, desired_range) - self.cfs = np.concatenate([uniques, remaining_cfs]) + self.population_size - len(self.cfs), features_to_vary, query_instance, desired_class, desired_range) + self.cfs = np.concatenate([self.cfs, remaining_cfs]) def do_cf_initializations(self, total_CFs, initialization, algorithm, features_to_vary, desired_range, desired_class, @@ -260,7 +263,7 @@ def _generate_counterfactuals(self, query_instance, total_CFs, initialization="k (see diverse_counterfactuals.py). 
""" - self.population_size = 10 * total_CFs + self.population_size = 3 * total_CFs self.start_time = timeit.default_timer() @@ -514,6 +517,9 @@ def find_counterfactuals(self, query_instance, desired_range, desired_class, if len(self.final_cfs) == self.total_CFs: print('Diverse Counterfactuals found! total time taken: %02d' % m, 'min %02d' % s, 'sec') + elif len(self.final_cfs) == 0: + print('No Counterfactuals found for the given configuration, perhaps try with different parameters...', + '; total time taken: %02d' % m, 'min %02d' % s, 'sec') else: print('Only %d (required %d) ' % (len(self.final_cfs), self.total_CFs), 'Diverse Counterfactuals found for the given configuation, perhaps ', diff --git a/dice_ml/explainer_interfaces/dice_pytorch.py b/dice_ml/explainer_interfaces/dice_pytorch.py index 09d257cc..0aaa52f3 100644 --- a/dice_ml/explainer_interfaces/dice_pytorch.py +++ b/dice_ml/explainer_interfaces/dice_pytorch.py @@ -1,16 +1,16 @@ """ Module to generate diverse counterfactual explanations based on PyTorch framework """ -import copy -import random -import timeit +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +import torch import numpy as np -import torch +import random +import timeit +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DicePyTorch(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_random.py b/dice_ml/explainer_interfaces/dice_random.py index 2995a398..b97c8fcf 100644 --- a/dice_ml/explainer_interfaces/dice_random.py +++ b/dice_ml/explainer_interfaces/dice_random.py @@ -3,15 +3,14 @@ Module to generate diverse counterfactual explanations based on random sampling. A simple implementation. """ -import random -import timeit - +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase import numpy as np import pandas as pd +import random +import timeit from dice_ml import diverse_counterfactuals as exp from dice_ml.constants import ModelTypes -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceRandom(ExplainerBase): @@ -109,11 +108,17 @@ class of query_instance for binary classification. cfs_df = None candidate_cfs = pd.DataFrame( np.repeat(query_instance.values, sample_size, axis=0), columns=query_instance.columns) - # Loop to change one feature at a time, then two features, and so on. + # Loop to change one feature at a time ##->(NOT TRUE), then two features, and so on. for num_features_to_vary in range(1, len(self.features_to_vary)+1): + # commented lines allow more values to change as num_features_to_vary increases, instead of .at you should use .loc + # is deliberately left commented out to let you choose. 
+ # the multi-feature variant is slower but more thorough, and still faster than the genetic or KD-tree methods. + # selected_features = np.random.choice(self.features_to_vary, (sample_size, num_features_to_vary), replace=True) selected_features = np.random.choice(self.features_to_vary, (sample_size, 1), replace=True) for k in range(sample_size): - candidate_cfs.at[k, selected_features[k][0]] = random_instances.at[k, selected_features[k][0]] + candidate_cfs.at[k, selected_features[k][0]] = random_instances._get_value(k, selected_features[k][0]) + # When only one feature is changed, _get_value is a faster scalar lookup than .at. + # candidate_cfs.loc[k, selected_features[k]] = random_instances.loc[k, selected_features[k]] scores = self.predict_fn(candidate_cfs) validity = self.decide_cf_validity(scores) if sum(validity) > 0: diff --git a/dice_ml/explainer_interfaces/dice_tensorflow1.py b/dice_ml/explainer_interfaces/dice_tensorflow1.py index 69ee8298..8ad5088b 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow1.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow1.py @@ -1,17 +1,17 @@ """ Module to generate diverse counterfactual explanations based on tensorflow 1.x """ -import collections -import copy -import random -import timeit +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +import tensorflow as tf import numpy as np -import tensorflow as tf +import random +import collections +import timeit +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceTensorFlow1(ExplainerBase): diff --git a/dice_ml/explainer_interfaces/dice_tensorflow2.py b/dice_ml/explainer_interfaces/dice_tensorflow2.py index 58445929..b8e1ab75 100644 --- a/dice_ml/explainer_interfaces/dice_tensorflow2.py +++ b/dice_ml/explainer_interfaces/dice_tensorflow2.py @@ -1,16 +1,16 @@ """ Module to generate diverse counterfactual explanations based on tensorflow 2.x """ -import copy -import random -import timeit +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +import tensorflow as tf import numpy as np -import tensorflow as tf +import random +import timeit +import copy from dice_ml import diverse_counterfactuals as exp from dice_ml.counterfactual_explanations import CounterfactualExplanations -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase class DiceTensorFlow2(ExplainerBase): @@ -177,7 +177,7 @@ def do_cf_initializations(self, total_CFs, algorithm, features_to_vary): # CF initialization if len(self.cfs) != self.total_CFs: self.cfs = [] - for _ in range(self.total_CFs): + for ix in range(self.total_CFs): one_init = [[]] for jx in range(self.minx.shape[1]): one_init[0].append(np.random.uniform(self.minx[0][jx], self.maxx[0][jx])) diff --git a/dice_ml/explainer_interfaces/explainer_base.py b/dice_ml/explainer_interfaces/explainer_base.py index 5f25d1cc..d1c73f4e 100644 --- a/dice_ml/explainer_interfaces/explainer_base.py +++ b/dice_ml/explainer_interfaces/explainer_base.py @@ -2,17 +2,17 @@ Subclasses implement interfaces for different ML frameworks such as TensorFlow or PyTorch. 
All methods are in dice_ml.explainer_interfaces""" +import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable - import numpy as np import pandas as pd -from sklearn.neighbors import KDTree from tqdm import tqdm -from dice_ml.constants import ModelTypes +from collections.abc import Iterable +from sklearn.neighbors import KDTree from dice_ml.counterfactual_explanations import CounterfactualExplanations from dice_ml.utils.exception import UserConfigValidationException +from dice_ml.constants import ModelTypes class ExplainerBase(ABC): @@ -50,7 +50,7 @@ def generate_counterfactuals(self, query_instances, total_CFs, desired_class="opposite", desired_range=None, permitted_range=None, features_to_vary="all", stopping_threshold=0.5, posthoc_sparsity_param=0.1, - posthoc_sparsity_algorithm="linear", verbose=False, **kwargs): + posthoc_sparsity_algorithm=None, verbose=False, **kwargs): """General method for generating counterfactuals. :param query_instances: Input point(s) for which counterfactuals are to be generated. @@ -81,6 +81,16 @@ def generate_counterfactuals(self, query_instances, total_CFs, if total_CFs <= 0: raise UserConfigValidationException( "The number of counterfactuals generated per query instance (total_CFs) should be a positive integer.") + if total_CFs > 10: + if posthoc_sparsity_algorithm is None: + posthoc_sparsity_algorithm = 'binary' + elif total_CFs > 50 and posthoc_sparsity_algorithm == 'linear': + warnings.warn("Generating {} counterfactuals per query instance with posthoc_sparsity_algorithm='linear' " + "may be slow; consider switching 'posthoc_sparsity_algorithm' from 'linear' to " + "'binary' if this step takes too long.".format(total_CFs)) + elif posthoc_sparsity_algorithm is None: + posthoc_sparsity_algorithm = 'linear' + cf_examples_arr = [] query_instances_list = [] if isinstance(query_instances, pd.DataFrame): @@ -88,7 +98,6 @@ def generate_counterfactuals(self, query_instances, total_CFs, query_instances_list.append(query_instances[ix:(ix+1)]) elif isinstance(query_instances, Iterable): query_instances_list = query_instances - for query_instance in tqdm(query_instances_list): self.data_interface.set_continuous_feature_indexes(query_instance) res = self._generate_counterfactuals( @@ -103,9 +112,6 @@ def generate_counterfactuals(self, query_instances, total_CFs, verbose=verbose, **kwargs) cf_examples_arr.append(res) - - self._check_any_counterfactuals_computed(cf_examples_arr=cf_examples_arr) - return CounterfactualExplanations(cf_examples_list=cf_examples_arr) @abstractmethod @@ -211,12 +217,10 @@ def local_feature_importance(self, query_instances, cf_examples_list=None, if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10 to compute feature importance for all query points") + "greater than or equal to 10") elif total_CFs < 10: - raise UserConfigValidationException( - "The number of counterfactuals requested per " - "query instance should be greater than or equal to 10 " - "to compute feature importance for all query points") + raise UserConfigValidationException("The number of counterfactuals generated per " + "query instance should be greater than or equal to 10") importances = self.feature_importance( query_instances, cf_examples_list=cf_examples_list, @@ -257,25 +261,16 @@ def global_feature_importance(self, query_instances, cf_examples_list=None, input, and the global 
feature importance summarized over all inputs. """ if query_instances is not None and len(query_instances) < 10: - raise UserConfigValidationException( - "The number of query instances should be greater than or equal to 10 " - "to compute global feature importance over all query points") + raise UserConfigValidationException("The number of query instances should be greater than or equal to 10") if cf_examples_list is not None: - if len(cf_examples_list) < 10: - raise UserConfigValidationException( - "The number of points for which counterfactuals generated should be " - "greater than or equal to 10 " - "to compute global feature importance") - elif any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): + if any([len(cf_examples.final_cfs_df) < 10 for cf_examples in cf_examples_list]): raise UserConfigValidationException( "The number of counterfactuals generated per query instance should be " - "greater than or equal to 10 " - "to compute global feature importance over all query points") + "greater than or equal to 10") elif total_CFs < 10: raise UserConfigValidationException( - "The number of counterfactuals requested per query instance should be greater " - "than or equal to 10 " - "to compute global feature importance over all query points") + "The number of counterfactuals generated per query instance should be greater " + "than or equal to 10") importances = self.feature_importance( query_instances, cf_examples_list=cf_examples_list, @@ -354,7 +349,7 @@ def feature_importance(self, query_instances, cf_examples_list=None, continue per_query_point_cfs = 0 - for _, row in df.iterrows(): + for index, row in df.iterrows(): per_query_point_cfs += 1 for col in self.data_interface.continuous_feature_names: if not np.isclose(org_instance[col].iat[0], row[col]): @@ -535,7 +530,7 @@ def misc_init(self, stopping_threshold, desired_class, desired_range, test_pred) self.target_cf_class = np.array( [[self.infer_target_cfs_class(desired_class, test_pred, self.num_output_nodes)]], dtype=np.float32) - desired_class = int(self.target_cf_class[0][0]) + desired_class = self.target_cf_class[0][0] if self.target_cf_class == 0 and self.stopping_threshold > 0.5: self.stopping_threshold = 0.25 elif self.target_cf_class == 1 and self.stopping_threshold < 0.5: @@ -700,15 +695,3 @@ def round_to_precision(self): self.final_cfs_df[feature] = self.final_cfs_df[feature].astype(float).round(precisions[ix]) if self.final_cfs_df_sparse is not None: self.final_cfs_df_sparse[feature] = self.final_cfs_df_sparse[feature].astype(float).round(precisions[ix]) - - def _check_any_counterfactuals_computed(self, cf_examples_arr): - """Check if any counterfactuals were generated for any query point.""" - no_cf_generated = True - # Check if any counterfactuals were generated for any query point - for cf_examples in cf_examples_arr: - if cf_examples.final_cfs_df is not None and len(cf_examples.final_cfs_df) > 0: - no_cf_generated = False - break - if no_cf_generated: - raise UserConfigValidationException( - "No counterfactuals found for any of the query points! 
Kindly check your configuration.") diff --git a/dice_ml/explainer_interfaces/feasible_base_vae.py b/dice_ml/explainer_interfaces/feasible_base_vae.py index 28dc7d13..c59a26d4 100644 --- a/dice_ml/explainer_interfaces/feasible_base_vae.py +++ b/dice_ml/explainer_interfaces/feasible_base_vae.py @@ -1,15 +1,16 @@ # General Imports import numpy as np -# Pytorch -import torch -import torch.utils.data -from torch.nn import functional as F -from dice_ml import diverse_counterfactuals as exp # Dice Imports from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml import diverse_counterfactuals as exp from dice_ml.utils.helpers import get_base_gen_cf_initialization +# Pytorch +import torch +import torch.utils.data +from torch.nn import functional as F + class FeasibleBaseVAE(ExplainerBase): @@ -186,7 +187,7 @@ def generate_counterfactuals(self, query_instance, total_CFs, desired_class="opp curr_cf_pred = [] curr_test_pred = train_y.numpy() - for _ in range(total_CFs): + for cf_count in range(total_CFs): recon_err, kl_err, x_true, x_pred, cf_label = \ self.cf_vae.compute_elbo(train_x, 1.0-train_y, self.pred_model) while(cf_label == train_y): diff --git a/dice_ml/explainer_interfaces/feasible_model_approx.py b/dice_ml/explainer_interfaces/feasible_model_approx.py index a7fffda5..01fc609a 100644 --- a/dice_ml/explainer_interfaces/feasible_model_approx.py +++ b/dice_ml/explainer_interfaces/feasible_model_approx.py @@ -1,13 +1,13 @@ # Dice Imports +from dice_ml.explainer_interfaces.explainer_base import ExplainerBase +from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE +from dice_ml.utils.helpers import get_base_gen_cf_initialization + # Pytorch import torch import torch.utils.data from torch.nn import functional as F -from dice_ml.explainer_interfaces.explainer_base import ExplainerBase -from dice_ml.explainer_interfaces.feasible_base_vae import FeasibleBaseVAE -from dice_ml.utils.helpers import get_base_gen_cf_initialization - class FeasibleModelApprox(FeasibleBaseVAE, ExplainerBase): diff --git a/dice_ml/model.py b/dice_ml/model.py index 737701bf..0518cdf8 100644 --- a/dice_ml/model.py +++ b/dice_ml/model.py @@ -4,7 +4,6 @@ frameworks such as Tensorflow or PyTorch. """ import warnings - from dice_ml.constants import BackEndTypes, ModelTypes from dice_ml.utils.exception import UserConfigValidationException @@ -70,8 +69,7 @@ def decide(backend): import tensorflow # noqa: F401 except ImportError: raise UserConfigValidationException("Unable to import tensorflow. 
Please install tensorflow") - from dice_ml.model_interfaces.keras_tensorflow_model import \ - KerasTensorFlowModel + from dice_ml.model_interfaces.keras_tensorflow_model import KerasTensorFlowModel return KerasTensorFlowModel elif backend == BackEndTypes.Pytorch: diff --git a/dice_ml/model_interfaces/base_model.py b/dice_ml/model_interfaces/base_model.py index 3a25b5cf..09b49ddd 100644 --- a/dice_ml/model_interfaces/base_model.py +++ b/dice_ml/model_interfaces/base_model.py @@ -3,12 +3,10 @@ All model interface methods are in dice_ml.model_interfaces""" import pickle - import numpy as np - +from dice_ml.utils.helpers import DataTransfomer from dice_ml.constants import ModelTypes from dice_ml.utils.exception import SystemException -from dice_ml.utils.helpers import DataTransfomer class BaseModel: @@ -64,7 +62,7 @@ def get_num_output_nodes(self, inp_size): temp_input = np.transpose(np.array([np.random.uniform(0, 1) for i in range(inp_size)]).reshape(-1, 1)) return self.get_output(temp_input).shape[1] - def get_num_output_nodes2(self, input_instance): + def get_num_output_nodes2(self, input): if self.model_type == ModelTypes.Regressor: raise SystemException('Number of output nodes not supported for regression') - return self.get_output(input_instance).shape[1] + return self.get_output(input).shape[1] diff --git a/dice_ml/model_interfaces/keras_tensorflow_model.py b/dice_ml/model_interfaces/keras_tensorflow_model.py index df150850..72619f12 100644 --- a/dice_ml/model_interfaces/keras_tensorflow_model.py +++ b/dice_ml/model_interfaces/keras_tensorflow_model.py @@ -1,10 +1,9 @@ """Module containing an interface to trained Keras Tensorflow model.""" +from dice_ml.model_interfaces.base_model import BaseModel import tensorflow as tf from tensorflow import keras -from dice_ml.model_interfaces.base_model import BaseModel - class KerasTensorFlowModel(BaseModel): @@ -40,7 +39,7 @@ def get_output(self, input_tensor, training=False, transform_data=False): else: return self.model(input_tensor) - def get_gradient(self, input_instance): + def get_gradient(self, input): # Future Support raise NotImplementedError("Future Support") diff --git a/dice_ml/model_interfaces/pytorch_model.py b/dice_ml/model_interfaces/pytorch_model.py index 9cf577cc..e0063f5a 100644 --- a/dice_ml/model_interfaces/pytorch_model.py +++ b/dice_ml/model_interfaces/pytorch_model.py @@ -1,8 +1,7 @@ """Module containing an interface to trained PyTorch model.""" -import torch - from dice_ml.model_interfaces.base_model import BaseModel +import torch class PyTorchModel(BaseModel): @@ -38,7 +37,7 @@ def get_output(self, input_tensor, transform_data=False): def set_eval_mode(self): self.model.eval() - def get_gradient(self, input_instance): + def get_gradient(self, input): # Future Support raise NotImplementedError("Future Support") diff --git a/dice_ml/utils/helpers.py b/dice_ml/utils/helpers.py index e3ea688f..866662c3 100644 --- a/dice_ml/utils/helpers.py +++ b/dice_ml/utils/helpers.py @@ -1,17 +1,17 @@ """ This module containts helper functions to load data and get meta deta. 
""" -import os -import shutil - import numpy as np import pandas as pd -from sklearn.model_selection import train_test_split -# for data transformations -from sklearn.preprocessing import FunctionTransformer +import shutil +import os import dice_ml +# for data transformations +from sklearn.preprocessing import FunctionTransformer +from sklearn.model_selection import train_test_split + def load_adult_income_dataset(only_train=True): """Loads adult income dataset from https://archive.ics.uci.edu/ml/datasets/Adult and prepares @@ -168,11 +168,11 @@ def get_base_gen_cf_initialization(data_interface, encoded_size, cont_minx, cont wm1, wm2, wm3, learning_rate): # Dice Imports - TODO: keep this method for VAE as a spearate module or move it to feasible_base_vae.py. # Check dependencies. + from dice_ml.utils.sample_architecture.vae_model import CF_VAE + # Pytorch from torch import optim - from dice_ml.utils.sample_architecture.vae_model import CF_VAE - # Dataset for training Variational Encoder Decoder model for CF Generation df = data_interface.normalize_data(data_interface.one_hot_encoded_data) encoded_data = df[data_interface.ohe_encoded_feature_names + [data_interface.outcome_name]] diff --git a/dice_ml/utils/sample_architecture/vae_model.py b/dice_ml/utils/sample_architecture/vae_model.py index 3b4be568..6d461494 100644 --- a/dice_ml/utils/sample_architecture/vae_model.py +++ b/dice_ml/utils/sample_architecture/vae_model.py @@ -109,7 +109,7 @@ def forward(self, x, c): res['z'] = [] res['x_pred'] = [] res['mc_samples'] = mc_samples - for _ in range(mc_samples): + for i in range(mc_samples): z = self.sample_latent_code(em, ev) x_pred = self.decoder(torch.cat((z, c), 1)) res['z'].append(z) @@ -239,7 +239,7 @@ def forward(self, x): res['z'] = [] res['x_pred'] = [] res['mc_samples'] = mc_samples - for _ in range(mc_samples): + for i in range(mc_samples): z = self.sample_latent_code(em, ev) x_pred = self.decoder(z) res['z'].append(z) diff --git a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb index 9786d701..fec9ce5f 100644 --- a/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb +++ b/docs/source/notebooks/Benchmarking_different_CF_explanation_methods.ipynb @@ -274,7 +274,7 @@ " for q in query_instances:\n", " if q in d.categorical_feature_names:\n", " query_instances.loc[:, q] = \\\n", - " [random.choice(dataset[q].values.unique()) for _ in query_instances.index]\n", + " [random.choice(dataset[q].unique()) for _ in query_instances.index]\n", " else:\n", " query_instances.loc[:, q] = \\\n", " [np.random.uniform(dataset[q].min(), dataset[q].max()) for _ in query_instances.index]\n", @@ -329,4 +329,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file