diff --git a/captum/attr/_core/lime.py b/captum/attr/_core/lime.py index b568c5cd0a..5099eaad5a 100644 --- a/captum/attr/_core/lime.py +++ b/captum/attr/_core/lime.py @@ -394,8 +394,12 @@ def attribute( >>> # input will be different and may have a smaller feature set, so >>> # an appropriate transformation function should be provided. >>> - >>> # Generating random input with size 2 x 5 - >>> input = torch.randn(2, 5) + >>> def to_interp_transform(curr_sample, original_inp, + >>> **kwargs): + >>> return curr_sample + >>> + >>> # Generating random input with size 1 x 5 + >>> input = torch.randn(1, 5) >>> # Defining LimeBase interpreter >>> lime_attr = LimeBase(net, SkLearnLinearModel("linear_model.Ridge"), @@ -403,7 +407,7 @@ def attribute( perturb_func=perturb_func, perturb_interpretable_space=False, from_interp_rep_transform=None, - to_interp_rep_transform=lambda x: x) + to_interp_rep_transform=to_interp_transform) >>> # Computes interpretable model, returning coefficients of linear >>> # model. >>> attr_coefs = lime_attr.attribute(input, target=1, kernel_width=1.1)